Description
Extracted from another issue. (Note: some typos may have been fixed in the extraction.)
NiklasGustafsson Apr 11, 2023
Last time I checked, both the tutorials and examples at dotnet/TorchSharpExamples built and ran on Linux and MacOS, as well as Windows. If that has regressed, please file a bug in that repo.
TorchSharp is a thin library on top of libtorch, and the API design was done to make it straightforward to build on the plethora of Python-based examples and tutorials that are out there, since we do not have the resources to create our own.
I can (most of the time) copy-and-paste tensor and module expressions from Python into C#, but there are inherent differences that cannot be overcome without programmer involvement:
Python and .NET memory management are different.
Python's syntax for passing arguments by name is different from C#'s.
We can't mimic the module invocation syntax module(x) without resorting to using dynamic, so (with the latest release) it's module.call(x) in C#.
Python's and C#'s statement/class/method, etc. syntax are obviously very different.
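As an illustration of the argument-passing difference, a Python keyword argument maps to a C# named argument with typed values. A minimal sketch, assuming the usual TorchSharp factory methods:

```csharp
using static TorchSharp.torch;

// Python: t = torch.randn(2, 3, dtype=torch.float64, device='cpu')
// C#: named arguments, with an enum dtype and a Device object instead of strings.
var t = randn(2, 3, dtype: ScalarType.Float64, device: CPU);
```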
I'm not sure what the old API you are referring to is -- a very long time ago, we switched to Python-like naming conventions (following the SciSharp community in that regard) and to the PyTorch scope hierarchy, which forced us to use a lot of static classes everywhere. The examples and tutorials online at dotnet/TorchSharpExamples do not use the old version of the APIs.
Here are some other resources:
https://github.com/dotnet/TorchSharp/wiki
https://github.com/dotnet/TorchSharpExamples/tree/main/src
https://github.com/dotnet/TorchSharpExamples/tree/main/tutorials/CSharp
michieal Apr 14, 2023
Well, ideally I would like to use it to make a LLaMA or Alpaca implementation in C#. But the "simple test" that I used to get to know this was this code:
import torch
model1_path = "./pytorch_model-00001-of-00003.bin"
model2_path = "./pytorch_model-00002-of-00003.bin"
model3_path = "./pytorch_model-00003-of-00003.bin"
merged_model_path = "./pytorch_model-13B.bin"
model1 = torch.load(model1_path, map_location=torch.device('cpu'))
model2 = torch.load(model2_path, map_location=torch.device('cpu'))
model3 = torch.load(model3_path, map_location=torch.device('cpu'))
# merge the models into a single dictionary
merged_model = {"model1": model1, "model2": model2, "model3": model3}
torch.save(merged_model, merged_model_path)
I mean, the Python script works; I tested it earlier. I would like to make a C# version of this so I don't have to have everything hard-coded. But the other day, I couldn't even do that much.
michieal Apr 14, 2023
You mentioned invoking modules using module(x)... can you tell me more about that?
I think that was one of the main points of failure that I experienced. There's nothing up front that says to do that (that I saw), and there's also no way to tell what modules can be invoked that way, or what x should be -- is it a string? Is it...? etc.
NiklasGustafsson Apr 14, 2023
When you pass data into a module in Python, you treat it as a callable object, and the forward method is called:
input = ...
module = ...
output = module(input)
Since C# has no operator() that can be overloaded, unlike C++, we cannot replicate that syntax in C# without resorting to dynamic, which we don't want to do. Therefore, you have to call call on the module, which allows hooks to be invoked, or you can call forward directly.
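In C#, the two invocation styles might look something like the following minimal sketch. The Doubler module here is a hypothetical example, and call requires a recent TorchSharp release, as noted above:

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Python: output = module(input)
// C#:     output = module.call(input)
var module = new Doubler();
var input = ones(3);
var viaCall = module.call(input);        // invokes hooks, then forward()
var viaForward = module.forward(input);  // forward() only, no hooks

// A tiny custom module (hypothetical example) deriving from nn.Module<Tensor, Tensor>.
class Doubler : nn.Module<Tensor, Tensor>
{
    public Doubler() : base(nameof(Doubler)) { RegisterComponents(); }
    public override Tensor forward(Tensor x) => x * 2;
}
```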
michieal Apr 14, 2023
okay, the code has statements like these: class FeedForward(nn.Module): and in it, it defines a forward function... I know that TS does forward functions (I've read at least that much, lol)... is this where I load a module using module(x), or is this creating a new module?
I guess I am asking how to interpret some of the Python, to know when to use the module call.
NiklasGustafsson Apr 14, 2023
That just means that FeedForward is derived from nn.Module. In C#, you should derive from one of the Module<T...> classes, preferably. The <T...> signature determines the signature of the forward function, which contains the logic of the forward pass. The backward pass is determined via autograd in the native runtime.
So, in Python you would call the forward function directly, if you want, but you usually treat the module class as a callable, i.e. a function-like object. Think of the Python module as having overloaded operator(...) by defining forward(...).
michieal Apr 14, 2023
ahhh okay. I was just checking on that. What about the use of a delegate to call it like a function? (Asking because it was suggested.)
But when I go to build those parts, I define it as a class that derives from nn.Module<type, type>, and fill in the two types from the forward function's input and return types... correct?
NiklasGustafsson Apr 14, 2023
Technically, it's Module<T, TResult>, where TResult is the return type of forward and T is any number of types that form its signature (I think we have it defined up to <T1,...,T6,TResult>).
There are some modules that have multiple forward signatures (Python deals with this dynamically), in which case you have to mix in IModule<T,...,TResult> for anything you don't consider mainstream. Specifically, Sequential only accepts IModule<Tensor,Tensor> components, so if your module has that signature (which most do), then that should be the main one. That said, multiple forwards is an uncommon situation.
michieal Apr 14, 2023
the code only has a single forward function per class def. Which, I guess, would be the module definition?
NiklasGustafsson Apr 14, 2023
Unfortunately, you have to look at the logic -- Python doesn't allow overloading of methods, so the forward() will figure out what was passed inside the body. In TorchSharp, we insist on static typing... :-)
michieal Apr 14, 2023
Well, here's the smallest code snippet from the source code. This is in the Model.py file. (I figure that it's small enough to work with here, to get an understanding.)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
I'm guessing, that I would use TorchScript to construct this? or, am I off there?
Also, I am today years old learning that python has lambda declarations. lol.
NiklasGustafsson Apr 14, 2023
No, you would translate it to C# manually. It would look something like this (I didn't try to compile it):
public class FeedForward : torch.nn.Module<Tensor, Tensor>
{
    private ColumnParallelLinear w1;
    private RowParallelLinear w2;
    private ColumnParallelLinear w3;

    public FeedForward(int dim, int hidden_dim, int multiple_of) : base(nameof(FeedForward))
    {
        hidden_dim = (int)(2 * hidden_dim / 3.0);
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of); // The last division must be integer division.
        w1 = new ColumnParallelLinear(dim, hidden_dim, bias: false, gather_output: false, init_method: x => x);
        w2 = new RowParallelLinear(hidden_dim, dim, bias: false, input_is_parallel: true, init_method: x => x);
        w3 = new ColumnParallelLinear(dim, hidden_dim, bias: false, gather_output: false, init_method: x => x);
        RegisterComponents();
    }

    public override Tensor forward(Tensor x)
    {
        using var _ = torch.NewDisposeScope();
        return w2.forward(functional.silu(w1.forward(x)) * w3.forward(x)).MoveToOuterDisposeScope();
    }
}
You can definitely use TorchScript, too, if PyTorch is able to export the model. However, it will then be a black box and you won't be able to modify it or use it to learn the details of how to use TorchSharp. The one benefit is you don't have to mess with translating the code.
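Once the translation compiles, usage might look like the sketch below. Note that ColumnParallelLinear and RowParallelLinear come from the LLaMA code base and would need C# counterparts (or could be replaced with plain nn.Linear for single-device use); the dimensions here are arbitrary:

```csharp
using static TorchSharp.torch;

var ff = new FeedForward(dim: 512, hidden_dim: 2048, multiple_of: 256);
var x = randn(8, 512);   // a batch of eight 512-dimensional inputs
var y = ff.call(x);      // runs any registered hooks, then forward()
// y should have shape [8, 512]: w2 maps hidden_dim back down to dim.
```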
michieal Apr 14, 2023
Well, I would rather learn. I'm not a fan of black boxes, especially in regards to code. And my design goals mean that I would need to use LLaMA in conjunction with a pre-filtering AI module to convert the user input into a viable input for the LLaMA section. (So trying not to toss out the word "module" all over and confuse the subject. lol.)
For that part, I was thinking that a BERT model would work well, as I am trying to (ultimately) make an AI assistant that you can ask questions and get a creative / helpful / mostly factual response.
NiklasGustafsson Apr 14, 2023
Just to give you an idea about the difference between call() and forward(), here's the Module<T,TResult> implementation of call:
public TResult call(T input)
{
// Call pre-hooks, if available.
foreach (var hook in pre_hooks.Values) {
var modified = hook(this, input);
if (modified is not null)
input = modified;
}
var result = forward(input);
// Call post-hooks, if available.
foreach (var hook in post_hooks.Values) {
var modified = hook(this, input, result);
if (modified is not null)
result = modified;
}
return result;
}
You should only implement (i.e. override) forward in your custom module.
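To see the hook machinery in action, here is a sketch of registering hooks on a built-in module. This assumes TorchSharp mirrors PyTorch's register_forward_pre_hook/register_forward_hook names; verify against the release you are using:

```csharp
using static TorchSharp.torch;

var relu = nn.ReLU();

// Pre-hook: transforms the input before forward() runs.
relu.register_forward_pre_hook((m, input) => input * 2);
// Post-hook: transforms the result after forward() runs.
relu.register_forward_hook((m, input, output) => output + 1);

var y = relu.call(tensor(new float[] { -1f, 3f }));
// forward() sees [-2, 6]; ReLU yields [0, 6]; the post-hook yields [1, 7].
// relu.forward(...) would bypass both hooks.
```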