Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(davide): GTrXL implementation #136

Merged
merged 76 commits into from
Mar 10, 2022

Conversation

davide97l
Copy link
Collaborator

Description

Implementation of the model GTrXL based on Transformer for RL.
Reference: https://arxiv.org/abs/1910.06764

davide97l and others added 2 commits December 2, 2021 13:47
@codecov
Copy link

codecov bot commented Dec 2, 2021

Codecov Report

Merging #136 (9e578fa) into main (e797618) will decrease coverage by 0.31%.
The diff coverage is 68.93%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #136      +/-   ##
==========================================
- Coverage   86.36%   86.05%   -0.32%     
==========================================
  Files         461      464       +3     
  Lines       35168    35825     +657     
==========================================
+ Hits        30373    30829     +456     
- Misses       4795     4996     +201     
Flag Coverage Δ
unittests 86.05% <68.93%> (-0.32%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
ding/policy/r2d2.py 89.93% <ø> (ø)
ding/policy/r2d2_gtrxl.py 17.81% <17.81%> (ø)
ding/model/template/q_learning.py 84.16% <20.00%> (-12.08%) ⬇️
ding/model/wrapper/model_wrappers.py 90.70% <74.31%> (-4.67%) ⬇️
ding/torch_utils/network/gtrxl.py 97.20% <97.20%> (ø)
ding/model/wrapper/test_model_wrappers.py 99.54% <100.00%> (+0.07%) ⬆️
ding/policy/__init__.py 100.00% <100.00%> (ø)
ding/policy/command_mode_policy_instance.py 94.31% <100.00%> (+0.13%) ⬆️
ding/torch_utils/network/__init__.py 100.00% <100.00%> (ø)
ding/torch_utils/network/tests/test_gtrxl.py 100.00% <100.00%> (ø)
... and 3 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e797618...9e578fa. Read the comment docs.

@davide97l davide97l changed the title (davide) Gtrxl implementation WIP feature(davide): Gtrxl implementation Dec 2, 2021
@davide97l davide97l added the algo Add new algorithm or improve old one label Dec 2, 2021
.. note::
Adapted from https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""
def __init__(self, embedding_dim):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add python typing lint

Compute positional embedding
Arguments:
- pos_seq: (:obj:`torch.Tensor`): positional sequence,
usually a 1D integer sequence as [seq_len-1, seq_len-2, ..., 1, 0],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why here is the order from largest to smallest

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing out this, both the original implementation https://github.com/kimiyoung/transformer-xl and the huggingface implementation https://github.com/huggingface/transformers/ of TransformerXL code the position in this way, I also tried the incremental order but the model didn't learn. I will further look into the problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here the order is the same, I found both orders in different Transformer implementations

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the case incremental order but the model didn't learn happened?

# For position embedding, the order of sin/cos is negligible.
# This is because tokens are consumed by the matrix multiplication which is permutation-invariant.
pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
return pos_embedding[:, None, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is easier to understand to use unsqueeze here

self.Uz = torch.nn.Linear(input_dim, input_dim)
self.Wg = torch.nn.Linear(input_dim, input_dim)
self.Ug = torch.nn.Linear(input_dim, input_dim)
self.bg = nn.Parameter(torch.zeros(input_dim).fill_(bg)) # bias
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.full can be better

Overview:
GRU Gating Unit used in GTrXL
"""
def __init__(self, input_dim, bg=0.2):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give some comments for why set bg=0.2 default

ding/torch_utils/network/gtrxl.py Show resolved Hide resolved
if memory is None:
self.reset(bs) # (layer_num+1) x memory_len x batch_size x embedding_dim
elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
print("Memory {} and Input {} dimensions don't match,"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use one_time_warning

"""
if self.memory is None or hidden_state is None:
return None
sequence_len = hidden_state[0].shape[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep suspicious of input arguments in computation graph, so we need to detach all the hidden state first, like this:
detached_h = [h.clone().detach() for h in hidden_state], and then just use no_grad context in following codes without any detach

self.layers = nn.Sequential(*layers)
self.embedding_dim = embedding_dim
# u and v are the parameters to compute global content bias and global positional bias
self.u, self.v = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how to initialize u and v

torch.triu(
torch.ones((cur_seq, cur_seq + prev_seq)),
diagonal=1 + prev_seq,
).bool()[..., None].to(x.device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use unsqueeze

@PaParaZz1 PaParaZz1 changed the title WIP feature(davide): Gtrxl implementation feature(davide): GTrXL implementation Feb 17, 2022
@PaParaZz1 PaParaZz1 merged commit bcc8179 into opendilab:main Mar 10, 2022
puyuan1996 pushed a commit to puyuan1996/DI-engine that referenced this pull request Apr 18, 2022
* Create gtrxl.py

Implementation of GTrXL. Still work in progress.

* first working version of GTrXL

* added variable memory lenght

* fixed some details in attention

* Delete test.py

* Delete temp.py

* added comments

* Update gtrxl.py

* Update gtrxl.py

* minor changes

* merged the 2 forward functions

* added tranformer wrapper and test

* modified a param in gru gate

* modified transformer wrapper

* Update gtrxl.py

* changed wrapper logic

* Refactored memory class

* added config and policy for cartpole (WIP)

* added test gtrxl

* many updates

* Update r2d2_gtrxl.py

* fixed bug in transformer wraper

* in this version gtrxl can learn but still not converge

* converge cartpole

* fixed according to comments

* added some tests and 2 params

* increased attention speed

* better way to handle memory

* added segmentation in training

* simplified rel_shift

* add lunarlander and bsuite config

* add burnin_step

* option to choose init memory method

* improved transformer input wrapper

* added memory wrapper

* finished all 3 wrappers

* Update r2d2_gtrxl.py

* Update model_wrappers.py

* polish code

* fixed cuda memory bug

* Update cartpole_r2d2_gtrxl_config.py

* atari support

* Update lunarlander_r2d2_config.py

* fixed bug in 3d obs

* Create spaceinvaders_r2d2_gtrxl_config.py

* updated pong and space conf

* updated pong and space configs

* cartpole best config

* more test wrappers and comments

* add memory wrapper in eval

* Update gtrxl.py

* Create qbert_r2d2_gtrxl_config.py

* lunarlander and qbert

* solve conficts

* solve conflicts

* format problems

* conflicts

* added some files back

* Update spaceinvaders_r2d2_gtrxl_config.py

* best conf lunarlander

* Update pong_r2d2_gtrxl_config.py

* best configs cartpole, pong, lunarlander

* format code

* add unit test

* updated configs

* update bsuite env and test
SolenoidWGT added a commit to SolenoidWGT/DI-engine that referenced this pull request Aug 22, 2023
* fix/fix_submodule_err (opendilab#61)

* fix/fix_submodule_err

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* fix issue templates (opendilab#65)

* fix(tokenizer): refactor tokenizer and update usage in readme (opendilab#51)

* update tokenizer example

* fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (opendilab#73)

* fix a typo in readme

* in order to find InternLMTokenizer, select a lower version of Transformers

---------

Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>

* [Doc] Add wechat and discord link in readme (opendilab#78)

* Doc:add wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* [Docs]: add Japanese README (opendilab#43)

* Add Japanese README

* Update README-ja-JP.md

replace message

* Update README-ja-JP.md

* add repetition_penalty in GenerationConfig in web_demo.py (opendilab#48)

Co-authored-by: YWMditto <862779238@qq.com>

* use fp16 in instruction (opendilab#80)

* [Enchancement] add more options for issue template (opendilab#77)

* [Enchancement] add more options for issue template

* update qustion icon

* fix link

* Use tempfile for convert2hf.py (opendilab#23)

Fix InternLM/InternLM#50

* delete torch_dtype of README's example code (opendilab#100)

* set the value of repetition_penalty to 1.0 to avoid random outputs (opendilab#99)

* Update web_demo.py (opendilab#97)

Remove meaningless log.

* [Fix]Fix wrong string cutoff in the script for sft text tokenizing (opendilab#106)

* docs(install.md): update dependency package transformers version to >= 4.28.0 (opendilab#124)

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* docs(LICENSE): add license (opendilab#125)

* add license of colossalai and flash-attn

* fix lint

* modify the name

* fix AutoModel map in convert2hf.py (opendilab#116)

* variables are not printly as expect (opendilab#114)

* feat(solver): fix code to adapt to torch2.0 and provide docker images (opendilab#128)

* feat(solver): fix code to adapt to torch2.0

* docs(install.md): publish internlm environment image

* docs(install.md): update dependency packages version

* docs(install.md): update default image

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* add demo test (opendilab#132)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* fix web_demo cache accelerate (opendilab#133)

* Doc: add twitter link (opendilab#141)

* Feat add checkpoint fraction (opendilab#151)

* feat(config): add checkpoint_fraction into config

* feat: remove checkpoint_fraction from configs/7B_sft.py

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

* [Doc] update deployment guide to keep consistency with lmdeploy (opendilab#136)

* update deployment guide

* fix error

* use llm partition (opendilab#159)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* test(ci_scripts): clean test data after test, remove unnecessary global variables, and other optimizations (opendilab#165)

* test: optimization of ci scripts(variables, test data cleaning, etc).

* chore(workflows): disable ci job on push.

* fix: update partition

* test(ci_scripts): add install requirements automaticlly,trigger event about lint check and other optimizations (opendilab#174)

* add pull_request in lint check

* use default variables in ci_scripts

* fix format

* check and install requirements automaticlly

* fix format

---------

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* feat(profiling): add a simple memory profiler (opendilab#89)

* feat(profiling): add simple memory profiler

* feat(profiling): add profiling argument

* feat(CI_workflow): Add PR & Issue auto remove workflow (opendilab#184)

* feat(ci_workflow): Add PR & Issue auto remove workflow

Add a workflow for stale PR & Issue  auto remove
- pr & issue well be labeled as stale for inactive in 7 days
- staled PR & Issue  well be remove in 7 days
- run this workflow every day on 1:30 a.m.

* Update stale.yml

* feat(bot): Create .owners.yml for Auto Assign (opendilab#176)

* Create .owners.yml: for issue/pr assign automatically

* Update .owners.yml

* Update .owners.yml

fix typo

* [feat]: add pal reasoning script (opendilab#163)

* [Feat] Add PAL inference script

* Update README.md

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update tools/pal_inference.py

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update pal script

* Update README.md

* restore .ore-commit-config.yaml

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update tools/README.md

Co-authored-by: BigDong <yudongwang1226@gmail.com>

* Update pal inference script

* Update READMD.md

* Update internlm/utils/interface.py

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>

* Update pal script

* Update pal script

* Update script

* Add docstring

* Update format

* Update script

* Update script

* Update script

---------

Co-authored-by: BigDong <yudongwang1226@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>

* test(ci_scripts): add timeout settings and clean work after the slurm job (opendilab#185)

* restore pr test on develop branch

* add mask

* add post action to cancel slurm job

* remove readonly attribute on job log

* add debug info

* debug job log

* try stdin

* use stdin

* set default value avoid error

* try setting readonly on job log

* performance echo

* remove debug info

* use squeue to check slurm job status

* restore the lossed parm

* litmit retry times

* use exclusive to avoid port already in use

* optimize loop body

* remove partition

* add {} for variables

* set env variable for slurm partition

---------

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* refactor(tools): move interface.py and import it to web_demo (opendilab#195)

* move interface.py and import it to web_demo

* typo

* fix(ci): fix lint error

* fix(ci): fix lint error

---------

Co-authored-by: Sun Peng <sunpengsdu@gmail.com>
Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>
Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com>
Co-authored-by: Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: zachtzy <141206206+zachtzy@users.noreply.github.com>
Co-authored-by: cx <759046501@qq.com>
Co-authored-by: Jaylin Lee <61487970+APX103@users.noreply.github.com>
Co-authored-by: del-zhenwu <dele.zhenwu@gmail.com>
Co-authored-by: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com>
Co-authored-by: BigDong <yudongwang1226@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: huangting4201 <huangting3@sensetime.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants