-
Notifications
You must be signed in to change notification settings - Fork 373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(davide): GTrXL implementation #136
Conversation
Implementation of GTrXL. Still work in progress.
Codecov Report
@@ Coverage Diff @@
## main #136 +/- ##
==========================================
- Coverage 86.36% 86.05% -0.32%
==========================================
Files 461 464 +3
Lines 35168 35825 +657
==========================================
+ Hits 30373 30829 +456
- Misses 4795 4996 +201
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
…I-engine into GTrXL-implementation
ding/torch_utils/network/gtrxl.py
Outdated
.. note:: | ||
Adapted from https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py | ||
""" | ||
def __init__(self, embedding_dim): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add python typing lint
Compute positional embedding | ||
Arguments: | ||
- pos_seq: (:obj:`torch.Tensor`): positional sequence, | ||
usually a 1D integer sequence as [seq_len-1, seq_len-2, ..., 1, 0], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why here is the order from largest to smallest
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing out this, both the original implementation https://github.com/kimiyoung/transformer-xl and the huggingface implementation https://github.com/huggingface/transformers/ of TransformerXL code the position in this way, I also tried the incremental order but the model didn't learn. I will further look into the problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think here the order is the same, I found both orders in different Transformer implementations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why the case incremental order but the model didn't learn
happened?
ding/torch_utils/network/gtrxl.py
Outdated
# For position embedding, the order of sin/cos is negligible. | ||
# This is because tokens are consumed by the matrix multiplication which is permutation-invariant. | ||
pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) | ||
return pos_embedding[:, None, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is easier to understand to use unsqueeze
here
ding/torch_utils/network/gtrxl.py
Outdated
self.Uz = torch.nn.Linear(input_dim, input_dim) | ||
self.Wg = torch.nn.Linear(input_dim, input_dim) | ||
self.Ug = torch.nn.Linear(input_dim, input_dim) | ||
self.bg = nn.Parameter(torch.zeros(input_dim).fill_(bg)) # bias |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.full
can be better
ding/torch_utils/network/gtrxl.py
Outdated
Overview: | ||
GRU Gating Unit used in GTrXL | ||
""" | ||
def __init__(self, input_dim, bg=0.2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
give some comments for why set bg=0.2
default
ding/torch_utils/network/gtrxl.py
Outdated
if memory is None: | ||
self.reset(bs) # (layer_num+1) x memory_len x batch_size x embedding_dim | ||
elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim: | ||
print("Memory {} and Input {} dimensions don't match," |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use one_time_warning
""" | ||
if self.memory is None or hidden_state is None: | ||
return None | ||
sequence_len = hidden_state[0].shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should keep suspicious of input arguments in computation graph, so we need to detach all the hidden state first, like this:
detached_h = [h.clone().detach() for h in hidden_state]
, and then just use no_grad
context in following codes without any detach
self.layers = nn.Sequential(*layers) | ||
self.embedding_dim = embedding_dim | ||
# u and v are the parameters to compute global content bias and global positional bias | ||
self.u, self.v = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how to initialize u and v
ding/torch_utils/network/gtrxl.py
Outdated
torch.triu( | ||
torch.ones((cur_seq, cur_seq + prev_seq)), | ||
diagonal=1 + prev_seq, | ||
).bool()[..., None].to(x.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use unsqueeze
…I-engine into GTrXL-implementation
* Create gtrxl.py Implementation of GTrXL. Still work in progress. * first working version of GTrXL * added variable memory lenght * fixed some details in attention * Delete test.py * Delete temp.py * added comments * Update gtrxl.py * Update gtrxl.py * minor changes * merged the 2 forward functions * added tranformer wrapper and test * modified a param in gru gate * modified transformer wrapper * Update gtrxl.py * changed wrapper logic * Refactored memory class * added config and policy for cartpole (WIP) * added test gtrxl * many updates * Update r2d2_gtrxl.py * fixed bug in transformer wraper * in this version gtrxl can learn but still not converge * converge cartpole * fixed according to comments * added some tests and 2 params * increased attention speed * better way to handle memory * added segmentation in training * simplified rel_shift * add lunarlander and bsuite config * add burnin_step * option to choose init memory method * improved transformer input wrapper * added memory wrapper * finished all 3 wrappers * Update r2d2_gtrxl.py * Update model_wrappers.py * polish code * fixed cuda memory bug * Update cartpole_r2d2_gtrxl_config.py * atari support * Update lunarlander_r2d2_config.py * fixed bug in 3d obs * Create spaceinvaders_r2d2_gtrxl_config.py * updated pong and space conf * updated pong and space configs * cartpole best config * more test wrappers and comments * add memory wrapper in eval * Update gtrxl.py * Create qbert_r2d2_gtrxl_config.py * lunarlander and qbert * solve conficts * solve conflicts * format problems * conflicts * added some files back * Update spaceinvaders_r2d2_gtrxl_config.py * best conf lunarlander * Update pong_r2d2_gtrxl_config.py * best configs cartpole, pong, lunarlander * format code * add unit test * updated configs * update bsuite env and test
* fix/fix_submodule_err (opendilab#61) * fix/fix_submodule_err --------- Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> * fix issue templates (opendilab#65) * fix(tokenizer): refactor tokenizer and update usage in readme (opendilab#51) * update tokenizer example * fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (opendilab#73) * fix a typo in readme * in order to find InternLMTokenizer, select a lower version of Transformers --------- Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> * [Doc] Add wechat and discord link in readme (opendilab#78) * Doc:add wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * Doc:update wechat and discord link * [Docs]: add Japanese README (opendilab#43) * Add Japanese README * Update README-ja-JP.md replace message * Update README-ja-JP.md * add repetition_penalty in GenerationConfig in web_demo.py (opendilab#48) Co-authored-by: YWMditto <862779238@qq.com> * use fp16 in instruction (opendilab#80) * [Enchancement] add more options for issue template (opendilab#77) * [Enchancement] add more options for issue template * update qustion icon * fix link * Use tempfile for convert2hf.py (opendilab#23) Fix InternLM/InternLM#50 * delete torch_dtype of README's example code (opendilab#100) * set the value of repetition_penalty to 1.0 to avoid random outputs (opendilab#99) * Update web_demo.py (opendilab#97) Remove meaningless log. * [Fix]Fix wrong string cutoff in the script for sft text tokenizing (opendilab#106) * docs(install.md): update dependency package transformers version to >= 4.28.0 (opendilab#124) Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * docs(LICENSE): add license (opendilab#125) * add license of colossalai and flash-attn * fix lint * modify the name * fix AutoModel map in convert2hf.py (opendilab#116) * variables are not printly as expect (opendilab#114) * feat(solver): fix code to adapt to torch2.0 and provide docker images (opendilab#128) * feat(solver): fix code to adapt to torch2.0 * docs(install.md): publish internlm environment image * docs(install.md): update dependency packages version * docs(install.md): update default image --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> * add demo test (opendilab#132) Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * fix web_demo cache accelerate (opendilab#133) * Doc: add twitter link (opendilab#141) * Feat add checkpoint fraction (opendilab#151) * feat(config): add checkpoint_fraction into config * feat: remove checkpoint_fraction from configs/7B_sft.py --------- Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> * [Doc] update deployment guide to keep consistency with lmdeploy (opendilab#136) * update deployment guide * fix error * use llm partition (opendilab#159) Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * test(ci_scripts): clean test data after test, remove unnecessary global variables, and other optimizations (opendilab#165) * test: optimization of ci scripts(variables, test data cleaning, etc). * chore(workflows): disable ci job on push. * fix: update partition * test(ci_scripts): add install requirements automaticlly,trigger event about lint check and other optimizations (opendilab#174) * add pull_request in lint check * use default variables in ci_scripts * fix format * check and install requirements automaticlly * fix format --------- Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * feat(profiling): add a simple memory profiler (opendilab#89) * feat(profiling): add simple memory profiler * feat(profiling): add profiling argument * feat(CI_workflow): Add PR & Issue auto remove workflow (opendilab#184) * feat(ci_workflow): Add PR & Issue auto remove workflow Add a workflow for stale PR & Issue auto remove - pr & issue well be labeled as stale for inactive in 7 days - staled PR & Issue well be remove in 7 days - run this workflow every day on 1:30 a.m. * Update stale.yml * feat(bot): Create .owners.yml for Auto Assign (opendilab#176) * Create .owners.yml: for issue/pr assign automatically * Update .owners.yml * Update .owners.yml fix typo * [feat]: add pal reasoning script (opendilab#163) * [Feat] Add PAL inference script * Update README.md * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update tools/pal_inference.py Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update pal script * Update README.md * restore .ore-commit-config.yaml * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update tools/README.md Co-authored-by: BigDong <yudongwang1226@gmail.com> * Update pal inference script * Update READMD.md * Update internlm/utils/interface.py Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> * Update pal script * Update pal script * Update script * Add docstring * Update format * Update script * Update script * Update script --------- Co-authored-by: BigDong <yudongwang1226@gmail.com> Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> * test(ci_scripts): add timeout settings and clean work after the slurm job (opendilab#185) * restore pr test on develop branch * add mask * add post action to cancel slurm job * remove readonly attribute on job log * add debug info * debug job log * try stdin * use stdin * set default value avoid error * try setting readonly on job log * performance echo * remove debug info * use squeue to check slurm job status * restore the lossed parm * litmit retry times * use exclusive to avoid port already in use * optimize loop body * remove partition * add {} for variables * set env variable for slurm partition --------- Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> * refactor(tools): move interface.py and import it to web_demo (opendilab#195) * move interface.py and import it to web_demo * typo * fix(ci): fix lint error * fix(ci): fix lint error --------- Co-authored-by: Sun Peng <sunpengsdu@gmail.com> Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu> Co-authored-by: Kai Chen <chenkaidev@gmail.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com> Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com> Co-authored-by: vansin <msnode@163.com> Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com> Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com> Co-authored-by: YWMditto <862779238@qq.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com> Co-authored-by: Shuo Zhang <zhangshuolove@live.com> Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com> Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn> Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com> Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Co-authored-by: wangguoteng.p <wangguoteng925@qq.com> Co-authored-by: lvhan028 <lvhan_028@163.com> Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com> Co-authored-by: cx <759046501@qq.com> Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com> Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com> Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com> Co-authored-by: BigDong <yudongwang1226@gmail.com> Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Co-authored-by: huangting4201 <huangting3@sensetime.com>
Description
Implementation of the model GTrXL based on Transformer for RL.
Reference: https://arxiv.org/abs/1910.06764