Stream of Search (SoS): Learning to Search in Language

1: Input: an initial policy parameter $\pi_{init}$, a reward function $R$, and an advantage coefficient $\lambda$.
2: $\pi_0 \leftarrow \pi_{init}$
3: $\pi_{ref} \leftarrow \pi_{init}$ ▷ Copy the SoS model to create a reference network.
4: $\pi_{value} \leftarrow \pi_{init}$ ▷ Copy the SoS model to create a value network.
5: for $t$ in $1 \ldots T$ do
6:     Roll out $\pi_{\theta_{t-1}}$ to produce the dataset $D_t = \{(s^{(t)}_{1}, a^{(t)}_{1}, r^{(t)}_{1}), \cdots, (s^{(t)}_{n}, a^{(t)}_{n}, r^{(t)}_{n})\}$
7:     Update the policy parameters according to
8:         $\theta_t = \arg\max_{\theta} \mathcal{L}_{APA}(\theta; D_t)$ ▷ We omit the critic loss for simplicity.
9:     where
10:        $\mathcal{L}_{APA}(\theta; D) = \frac{1}{|D|} \sum_{(s,a) \in D} \big(\log \pi_{\theta}(a|s) - \frac{Adv^{\pi_{\theta_{t-1}}}(s,a)}{\lambda} - \log \pi_{ref}(a|s)\big)^{2}$
11:    If the validation reward converges, update $\pi_{ref}$:
12:        $\pi_{ref} \leftarrow \pi_{\theta_{t}}$
13: end for
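To make step 10 concrete, here is a minimal PyTorch sketch of the $\mathcal{L}_{APA}$ computation. It assumes the per-sample log-probabilities and advantage estimates have already been gathered from a rollout; the function name `apa_loss` and its tensor arguments are illustrative placeholders, not the paper's implementation.

```python
import torch

def apa_loss(logp_theta: torch.Tensor,
             logp_ref: torch.Tensor,
             advantages: torch.Tensor,
             lam: float) -> torch.Tensor:
    """Squared-residual APA objective from step 10 (illustrative sketch).

    logp_theta : log pi_theta(a|s) under the current policy, shape (N,)
    logp_ref   : log pi_ref(a|s) under the frozen reference policy, shape (N,)
    advantages : Adv^{pi_{theta_{t-1}}}(s, a) estimated from the rollout, shape (N,)
    lam        : advantage coefficient lambda
    """
    # Residual of log pi_theta(a|s) against the target log pi_ref(a|s) + Adv(s, a) / lambda.
    residual = logp_theta - advantages / lam - logp_ref
    # Average the squared residual over the rollout dataset D.
    return residual.pow(2).mean()
```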

Summary

The algorithm specifies reinforcement-learning fine-tuning of the SoS policy given a reward function $R$ and an advantage coefficient $\lambda$. At each iteration, the current policy is rolled out to collect a dataset of state, action, and reward tuples, and the policy parameters are updated by optimizing the APA objective $\mathcal{L}_{APA}$, which drives the policy's log-probabilities toward those of a frozen reference policy shifted by the advantage scaled by $1/\lambda$. Whenever the validation reward converges, the reference policy is refreshed with the current policy, and the loop continues for $T$ iterations.
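To show how these pieces fit together, the following is a schematic sketch of the outer loop (steps 5 to 12), reusing the `apa_loss` sketch above. The helpers `rollout`, `estimate_advantages`, and `validation_reward_converged`, as well as the `policy.log_prob` interface, are hypothetical placeholders for components the algorithm leaves abstract (including the value network used for advantage estimation).

```python
import copy
import torch

def train_apa(policy, ref_policy, optimizer, lam, num_iterations,
              rollout, estimate_advantages, validation_reward_converged):
    """Schematic outer loop for steps 5-12; all helper callables are assumed."""
    for t in range(1, num_iterations + 1):
        # Step 6: roll out the current policy to collect (s, a, r) tuples.
        states, actions, rewards = rollout(policy)

        # Estimate advantages under the rollout policy (value-network details omitted).
        advantages = estimate_advantages(states, actions, rewards)

        # Steps 7-10: a gradient step on the APA objective over the rollout dataset.
        logp_theta = policy.log_prob(states, actions)
        with torch.no_grad():
            logp_ref = ref_policy.log_prob(states, actions)
        loss = apa_loss(logp_theta, logp_ref, advantages, lam)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Steps 11-12: refresh the reference network once validation reward converges.
        if validation_reward_converged(policy):
            ref_policy = copy.deepcopy(policy)
    return policy
```

The sketch takes a single gradient step per rollout for brevity; the optimization in step 8 would correspond to iterating such steps on each dataset $D_t$ before collecting the next batch.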