Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto compile Lux models to reactant #665

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft

Auto compile Lux models to reactant #665

wants to merge 15 commits into from

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented May 27, 2024

Example Usage

This follows the same structure as SimpleChains. User demands a conversion and provides an input prototype.

using Reactant, Lux, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3); force_compile_backward=true)(model)
ps, st = Lux.setup(Random.default_rng(), reactant_model)

x = randn(Float32, 10, 3)

reactant_model(x, ps, st)

Upstream Needs

TODOs

  • Compile Forward Pass
  • Compile Inference Pass
  • Compile VJP (using Enzyme)
    • Support the standard AD backends as well via ChainRules
    • Add Enzyme Rules to directly call the compiled function
  • Compile JVP (using Enzyme)
    • Support ForwardDiff
    • Add Enzyme ForwardDiff Rules
  • __make_reactant_array
  • ComponentArrays Special Handling
  • Add documentation (make sure that users know that this is experimental)
  • Add a tutorial similar to the SimpleChains one
    • Full Training Pipeline using Training API
    • Partial compilation for NN using XLA and remaining via LLVM (like Neural ODEs)
  • Add Reactant to our benchmark suite
  • Add tests
    • CPU
    • CUDA
  • (Nice to have) Extend to all models, not just stateless ones.
  • Compile the training loop
    • Fallback implementation for existing backends
    • Reactant Backend
    • Add to documentation

ext/LuxReactantExt.jl Outdated Show resolved Hide resolved
@avik-pal

This comment was marked as outdated.

@avik-pal

This comment was marked as outdated.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 8ce9707 Previous: 60c595e Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3669.375 ns 3646.75 ns 1.01
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7116.5 ns 7285 ns 0.98
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20819 ns 21210 ns 0.98
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9519.8 ns 9781.666666666666 ns 0.97
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8884.25 ns 9087.2 ns 0.98
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4494.625 ns 4453.888888888889 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1168.6785714285713 ns 1176.2706766917292 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1112.012658227848 ns 1112.28025477707 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1181.0820895522388 ns 1189.374074074074 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1780.4406779661017 ns 1814.3181818181818 ns 0.98
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.55460992907803 ns 179.93324061196105 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17252 ns 17212 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17292 ns 17463 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36588 ns 36689 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 27992 ns 28147.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19677 ns 20058 ns 0.98
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 16982 ns 16921 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4298 ns 4310.5 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3867.25 ns 3867.25 ns 1
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3952.375 ns 3951.125 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4775.357142857143 ns 4787.571428571428 ns 1.00
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1664.2 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 39058487.5 ns 38839150 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58048876.5 ns 57478179 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 69429268 ns 68637336 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88231886 ns 80248739.5 ns 1.10
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72608460 ns 66510498 ns 1.09
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12052157 ns 11601127 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 8361274 ns 8302158.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7009826 ns 6958814.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6996013 ns 6935871 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9927628.5 ns 9886349 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6387327.5 ns 6377484 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 714625608 ns 711495815.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2809656510 ns 2802293498 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 143359748 ns 158450926 ns 0.90
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 742274565 ns 745197995 ns 1.00
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2535181882 ns 2536517155 ns 1.00
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 194702054 ns 186814591 ns 1.04
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 646623448.5 ns 698620045 ns 0.93
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2730166614 ns 2703329300 ns 1.01
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 119575299 ns 122294200.5 ns 0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 194591286 ns 172044480 ns 1.13
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 648527666 ns 643441503 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45366635 ns 45114156 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164210512 ns 163454975.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 636442834 ns 628139701 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29909710.5 ns 29335904 ns 1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 186029419 ns 207955667.5 ns 0.89
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 711562847 ns 722173872 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35379449 ns 37423155 ns 0.95
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1232480179.5 ns 1242027523.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1864117910.5 ns 1847309072 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2009631582.5 ns 1988297584 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2333653016.5 ns 2337208631 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1783708622.5 ns 1825164998 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 344253318 ns 347875405.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 316840031.5 ns 318366365 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 318335263.5 ns 319738018 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 344631084 ns 452781616 ns 0.76
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11976289 ns 11803413 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18037586 ns 17882962 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19230996 ns 19018033 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23857907 ns 23755630 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18084016 ns 17832966.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1155306.5 ns 1148767 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2520956 ns 2512938 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2047653 ns 2035570 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2054475 ns 2023578.5 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2070591 ns 2055760 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 201787 ns 195727.5 ns 1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 289931.5 ns 288322 ns 1.01
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 264083 ns 262603 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 365111 ns 354936.5 ns 1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 406990 ns 400938 ns 1.02
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 272469 ns 270257 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 405988 ns 397421 ns 1.02
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83225.5 ns 83306 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 80735.5 ns 80271 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81221 ns 80581 ns 1.01
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86432 ns 85480 ns 1.01
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104584 ns 104617 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 193678547 ns 187932820.5 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 322350064 ns 321827872.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 390601991 ns 393773632.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 455632147 ns 454117809 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 342018490 ns 366877761 ns 0.93
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 324867768.5 ns 309426428 ns 1.05
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 51185900 ns 51303991 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43468243 ns 43675671.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43323005 ns 43447693 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 59553933.5 ns 49289683 ns 1.21
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 27964827 ns 28489085 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18639172.5 ns 18511523 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19409965 ns 19373919.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23095282.5 ns 22860858 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24104485.5 ns 23821494.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19543279.5 ns 19452776.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6497967 ns 6471809.5 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6517959 ns 6467840.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6472398.5 ns 6458192 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6506948 ns 6475071.5 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Copy link

codecov bot commented May 27, 2024

Codecov Report

Attention: Patch coverage is 0.62500% with 159 lines in your changes missing coverage. Please review.

Project coverage is 80.26%. Comparing base (60c595e) to head (8ce9707).
Report is 26 commits behind head on main.

Files Patch % Lines
ext/LuxReactantExt/layer.jl 0.00% 101 Missing ⚠️
src/transform/reactant.jl 0.00% 15 Missing ⚠️
ext/LuxReactantExt/utils.jl 0.00% 13 Missing ⚠️
ext/LuxReactantExt/train.jl 0.00% 12 Missing ⚠️
src/layers/extension.jl 0.00% 10 Missing ⚠️
src/contrib/training.jl 0.00% 6 Missing ⚠️
src/utils.jl 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #665      +/-   ##
==========================================
- Coverage   87.11%   80.26%   -6.85%     
==========================================
  Files          50       55       +5     
  Lines        2515     2671     +156     
==========================================
- Hits         2191     2144      -47     
- Misses        324      527     +203     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@avik-pal

This comment was marked as outdated.

@wsmoses
Copy link
Contributor

wsmoses commented May 27, 2024

@avik-pal can this be done in a way that reactant compiles the whole update, not just the gradient as separate from the inference pass. Specifically, I expect there to be a substantial perf improvement from doing so -- including the model update actually fully occuring in place.

E.g. the function reactant compiles being something like

function update(model, x, learning_rate)
   grads = gradient(model, x)
   update!(model, grads, learning_rate[1])
   nothing
end

@avik-pal
Copy link
Member Author

Not with the layers API. Currently, if we can accelerate just the neural network part, I would consider it a good win. Also, having it like this makes it possible to use regular Julia ops for cases where we can't compile to Reactant, for example, the ODE solves happen in Julia and the neural network is via XLA.

We can add AutoReactant for the training API, where we can compile the entire pass in the first call, and reuse it in subsequent calls (similar to what we do for Enzyme).

@avik-pal

This comment was marked as outdated.

@avik-pal
Copy link
Member Author

Okay things are working mostly now, we just need a copyto! for TracedRArray

ext/LuxReactantExt.jl Outdated Show resolved Hide resolved
ext/LuxReactantExt.jl Outdated Show resolved Hide resolved
@avik-pal
Copy link
Member Author

avik-pal commented May 27, 2024

using Reactant, Lux, Random, ComponentArrays

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(Float32, 10, 3))(model)

Gives me a Pipeline Failed with error: size of operand dimension 0 (3) is not equal to 1 or size of result dimension 0 (2). But the gradient seems to work fine for Enzyme with the XLA compilation.

@avik-pal avik-pal force-pushed the ap/reactant branch 2 times, most recently from 2c67548 to fb7ea0a Compare May 29, 2024 02:05
@wsmoses
Copy link
Contributor

wsmoses commented May 29, 2024

@avik-pal you should oopen an issue with the pipeline error on Reactant, once the prereqs are merged

@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)

# Reactant doesn't handle mixed eltypes that well, so we will first try to compile it as
# a usual julia function. However, if that fails, we will type cast and try to recompile.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be able to fix mixed eltypes if you have an mwe by chance

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using Reactant, Lux, Random

model = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)

reactant_model = ToReactantAdaptor{true}(rand(10, 3))(model)

ps = __try_similar_structure(Lux.__named_tuple(ps), l.concrete_ps)
ps = l.adaptor(ps)
l.eltype_adaptor !== nothing && (ps = adapt(l.eltype_adaptor, ps))
ps = __make_concrete_array(ps)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be set up where the ps are already reactant arrays so we don't need to call __make_concrete_array

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[hope would be avoiding data shuffling, especially cpu<->gpu of the whole model]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be hit in most cases. I wasn't able to compile a ComponentArrays based version yet (needs a make_tracer overload), so it is a temporary solution.

The correct use of this should hit L163

src/contrib/training.jl Outdated Show resolved Hide resolved
src/contrib/training.jl Outdated Show resolved Hide resolved
@avik-pal
Copy link
Member Author

@wsmoses seems like an incorrect generation?

Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<3x10xf32>, %arg3: tensor<10x5xf32>, %arg4: tensor<10x5xf32>) -> (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<5x3xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<5xf32>
    %0 = stablehlo.reshape %arg1 : (tensor<1x5xf32>) -> tensor<5xf32>
    %1 = stablehlo.dot_general %arg4, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x5xf32>, tensor<3x10xf32>) -> tensor<5x3xf32>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5x3xf32>
    %3 = stablehlo.add %1, %2 : tensor<5x3xf32>
    %4 = stablehlo.tanh %3 : tensor<5x3xf32>
    %5 = stablehlo.reduce(%4 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<5x3xf32>, tensor<f32>) -> tensor<f32>
    %6 = stablehlo.multiply %4, %4 : tensor<5x3xf32>
    %7 = stablehlo.subtract %cst, %6 : tensor<5x3xf32>
    %8 = stablehlo.reduce(%7 init: %cst_1) across dimensions = [1] : (tensor<5x3xf32>, tensor<5xf32>) -> tensor<5xf32>
     reducer(%arg5: tensor<5xf32>, %arg6: tensor<5xf32>)  {
      %13 = stablehlo.add %arg5, %arg6 : tensor<5xf32>
      stablehlo.return %13 : tensor<5xf32>
    }
    %9 = stablehlo.dot_general %arg2, %7, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x10xf32>, tensor<5x3xf32>) -> tensor<10x5xf32>
    %10 = stablehlo.add %arg3, %9 : tensor<10x5xf32>
    %11 = stablehlo.reshape %8 : (tensor<5xf32>) -> tensor<1x5xf32>
    %12 = stablehlo.add %arg0, %11 : tensor<1x5xf32>
    return %12, %5, %10 : tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>
  }
}
terminate called after throwing an instance of 'xla::XlaRuntimeError'
  what():  UNKNOWN: <unknown>:0: error: Reduction function must return a scalar or tuple of scalars but returns shape: f32[5]: 
<unknown>:0: note: see current operation: "func.return"(%15, %8, %13) : (tensor<1x5xf32>, tensor<f32>, tensor<10x5xf32>) -> ()

@wsmoses
Copy link
Contributor

wsmoses commented May 31, 2024

@avik-pal the lux fixes (and named tuple) just landed and were released.

I'll give the reduction error a go shortly, but at minimum we can see what works (and perhaps mark that as expected broken to start with)

@avik-pal
Copy link
Member Author

avik-pal commented Jun 1, 2024

Currently the julia session crashes because of the broken reverse pass, so can't mark it broken

src/transform/reactant.jl Outdated Show resolved Hide resolved
@avik-pal
Copy link
Member Author

avik-pal commented Jun 1, 2024

Can we have a no copy transfer between Julia AbstractArrays and Reactant/XLA Arrays? This makes life simpler to support wrapper types like ComponentArrays.

Also we can keep the parameters as regular Julia arrays which works more nicely with the current optimisers and such

@wsmoses
Copy link
Contributor

wsmoses commented Jun 1, 2024 via email

@avik-pal
Copy link
Member Author

avik-pal commented Jun 1, 2024

Also for better performance the optimizers themselves are compiled by reactant

Right but I don't think we would be able to compile NeuralODE style models yet right? So having an eager version that can perform operations directly on RArrays seems like a good tradeoff to run part of the model is regular Julia.

I might pull out the AutoReactant code (compiling the training iteration) into a separate PR because that would be easier to merge.

@wsmoses
Copy link
Contributor

wsmoses commented Jun 1, 2024 via email

@avik-pal avik-pal mentioned this pull request Jun 1, 2024
18 tasks
@avik-pal avik-pal added the xla label Jun 1, 2024
@avik-pal avik-pal force-pushed the main branch 5 times, most recently from 25325ea to e931a5e Compare June 16, 2024 02:10
@wsmoses
Copy link
Contributor

wsmoses commented Jun 18, 2024

@avik-pal fix has landed, can we retry this?

@avik-pal
Copy link
Member Author

This one is too broadly scoped, so I will hold it off.

First, I want to finish #673, which compiles the entire training loop and doesn't need to worry about users doing unwanted things to the parameters.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants