diff --git a/client/README.md b/client/README.md
index 9e676c6b..849ba4eb 100644
--- a/client/README.md
+++ b/client/README.md
@@ -370,6 +370,7 @@ type GPU struct {
 	ID       string `json:"id"`
 	Vendor   string `json:"vendor"`
 	Device   string `json:"device"`
+	Vram     uint64 `json:"vram"`
 	Contract uint64 `json:"contract"`
 }
 ```
diff --git a/client/node.go b/client/node.go
index c5b70555..2d48753b 100644
--- a/client/node.go
+++ b/client/node.go
@@ -132,6 +132,7 @@ type GPU struct {
 	ID       string `json:"id"`
 	Vendor   string `json:"vendor"`
 	Device   string `json:"device"`
+	Vram     uint64 `json:"vram"`
 	Contract uint64 `json:"contract"`
 }
 
diff --git a/docs/manual/api.md b/docs/manual/api.md
index 250eb00f..2f58b00e 100644
--- a/docs/manual/api.md
+++ b/docs/manual/api.md
@@ -248,6 +248,7 @@ GPU {
     "id": "string"
     "vendor": "string"
     "device": "string",
+    "vram": "uint64",
     "contract": "uint64",
 }
 ```
diff --git a/pkg/capacity/capacity.go b/pkg/capacity/capacity.go
index b18fc5f7..880846e5 100644
--- a/pkg/capacity/capacity.go
+++ b/pkg/capacity/capacity.go
@@ -1,14 +1,18 @@
 package capacity
 
 import (
+	"encoding/xml"
+	"fmt"
 	"os"
 	"os/exec"
+	"strconv"
 	"strings"
 	"syscall"
 
 	"github.com/pkg/errors"
 	"github.com/rs/zerolog/log"
 	"github.com/shirou/gopsutil/host"
+	"github.com/threefoldtech/zosbase/pkg"
 	"github.com/threefoldtech/zosbase/pkg/capacity/dmi"
 	"github.com/threefoldtech/zosbase/pkg/capacity/smartctl"
 	"github.com/threefoldtech/zosbase/pkg/gridtypes"
@@ -182,3 +186,116 @@ func (r *ResourceOracle) GPUs() ([]PCI, error) {
 	}
 	return ListPCI(GPU)
 }
+
+// normalizeBusID converts a PCI bus ID with a long domain, such as
+// "00000000:01:00.0", to the four-digit domain form "0000:01:00.0". Input that
+// does not consist of exactly three colon-separated parts is returned unchanged.
+func normalizeBusID(busID string) string {
+	parts := strings.Split(busID, ":")
+	if len(parts) != 3 {
+		return busID
+	}
+	// Strip the leading zeros, then left-pad explicitly to four digits rather
+	// than relying on the fmt '0' flag being applied to a %s verb.
+	domain := strings.TrimLeft(parts[0], "0")
+	if len(domain) < 4 {
+		domain = strings.Repeat("0", 4-len(domain)) + domain
+	}
+	return domain + ":" + parts[1] + ":" + parts[2]
+}
+
+// DisplayNode represents a display device from lshw XML output.
+type DisplayNode struct {
+	Class   string `xml:"class,attr"`
+	BusInfo string `xml:"businfo"`
+	Product string `xml:"product"`
+	Vendor  string `xml:"vendor"`
+	// Resources collects the value attribute of every <resource> element
+	// found under <resources>, whatever kind of resource it describes.
+	Resources struct {
+		Memory []struct {
+			Value string `xml:"value,attr"`
+		} `xml:"resource"`
+	} `xml:"resources"`
+}
+
+// DisplayList represents the root XML structure from lshw.
+type DisplayList struct {
+	Nodes []DisplayNode `xml:"node"`
+}
+
+// memoryRangeMiB parses an address range written as "start-end", with both
+// bounds in hexadecimal, and returns its size in MiB rounded down. The second
+// result is false when value is not such a range or when end is below start.
+func memoryRangeMiB(value string) (uint64, bool) {
+	bounds := strings.Split(value, "-")
+	if len(bounds) != 2 {
+		return 0, false
+	}
+	start, err := strconv.ParseUint(strings.TrimSpace(bounds[0]), 16, 64)
+	if err != nil {
+		return 0, false
+	}
+	end, err := strconv.ParseUint(strings.TrimSpace(bounds[1]), 16, 64)
+	if err != nil {
+		return 0, false
+	}
+	if end < start {
+		// An inverted range would wrap around in unsigned arithmetic and be
+		// reported as an enormous size.
+		return 0, false
+	}
+	return (end - start + 1) / (1024 * 1024), true
+}
+
+// GetGpuDevice gets the GPU information for the PCI device p using the lshw
+// command, which is executed on every call. The reported Vram is the size in
+// MiB of the largest address range listed among the resources of the display
+// node whose bus ID matches p.Slot, or 0 when no range can be parsed.
+//
+// An error is returned when lshw cannot be run, when its XML output cannot be
+// parsed, when no display node matches p.Slot, or when p.GetDevice reports no
+// vendor and device information.
+func GetGpuDevice(p *PCI) (pkg.GPUInfo, error) {
+	output, err := exec.Command("lshw", "-C", "display", "-xml").Output()
+	if err != nil {
+		return pkg.GPUInfo{}, fmt.Errorf("failed to run lshw command: %w", err)
+	}
+
+	var displayList DisplayList
+	if err := xml.Unmarshal(output, &displayList); err != nil {
+		return pkg.GPUInfo{}, fmt.Errorf("failed to parse lshw XML output: %w", err)
+	}
+
+	for _, node := range displayList.Nodes {
+		if node.Class != "display" || !strings.HasPrefix(node.BusInfo, "pci@") {
+			continue
+		}
+		if normalizeBusID(strings.TrimPrefix(node.BusInfo, "pci@")) != p.Slot {
+			continue
+		}
+
+		// The largest address range exposed by the device is taken as its
+		// video memory.
+		var vram uint64
+		for _, resource := range node.Resources.Memory {
+			if size, ok := memoryRangeMiB(resource.Value); ok && size > vram {
+				vram = size
+			}
+		}
+
+		vendor, device, ok := p.GetDevice()
+		if !ok {
+			return pkg.GPUInfo{}, fmt.Errorf("failed to get vendor and device info")
+		}
+
+		return pkg.GPUInfo{
+			ID:     p.ShortID(),
+			Vendor: vendor.Name,
+			Device: device.Name,
+			Vram:   vram,
+		}, nil
+	}
+
+	return pkg.GPUInfo{}, fmt.Errorf("gpu not found in lshw output")
+}
diff --git a/pkg/primitives/statistics.go b/pkg/primitives/statistics.go
index e85abf08..635e71a0 100644
--- a/pkg/primitives/statistics.go
+++ b/pkg/primitives/statistics.go
@@ -315,10 +315,14 @@ func (s *statsStream) ListGPUs() ([]pkg.GPUInfo, error) {
 
 	for _, pciDevice := range devices {
 		id := pciDevice.ShortID()
+		// VRAM detection is best effort: on any error GetGpuDevice returns a
+		// zero-valued GPUInfo, so Vram is reported as 0 for this device.
+		gpu, _ := capacity.GetGpuDevice(&pciDevice)
 		info := pkg.GPUInfo{
 			ID:       id,
 			Vendor:   "unknown",
 			Device:   "unknown",
+			Vram:     gpu.Vram,
 			Contract: used[id],
 		}
 
diff --git a/pkg/provision.go b/pkg/provision.go
index 3047b3a5..6cfa8b23 100644
--- a/pkg/provision.go
+++ b/pkg/provision.go
@@ -58,5 +58,6 @@ type GPUInfo struct {
 	ID       string `json:"id"`
 	Vendor   string `json:"vendor"`
 	Device   string `json:"device"`
+	Vram     uint64 `json:"vram"`
 	Contract uint64 `json:"contract"`
 }