Skip to content

Commit ee57d65

Browse files
authored
Override defaults if GPU (#664)
* first commit * missing comma * missing function * simplification * updates tables * logging no 1 * more debug * fixing things and log them * fixing * accelerator type is just a string not url * Revert "accelerator type is just a string not url" This reverts commit c09b6fe. * more logging and stuff * more logs * typo * more and more logging * use zone instead of region * styles * do no stack accelerator type * removing debug * always load defaults for gpu plan * forgotten line * Some nice logs * pasing only count * adding statement * disk spart way * cleanup * do not assign VMsize if gpu VM Type * Add gpu to query tags (#670) * added gpu to the tags list for quering images in api-selector * added debug lines * add gpu_vm_type to params in api selector * removed debug lines * Update CHANGELOG.md
1 parent b840303 commit ee57d65

File tree

4 files changed

+109
-13
lines changed

4 files changed

+109
-13
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/).
1212

1313
### Removed
1414

15+
### Added
16+
- Adding GPU Support
17+
1518
### Fixed
1619

1720
## [6.2.4] - 2019-10-29

backend/gce.go

Lines changed: 97 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,20 @@ Set-LocalUser -Name travis -Password $pw
151151
"large": "n2-standard-4",
152152
"x-large": "n2-standard-8",
153153
"2x-large": "n2-standard-16",
154+
"gpu-medium": "n1-standard-8",
155+
"gpu-xlarge": "n1-standard-8",
154156
}
155157
)
156158

159+
func stringInSlice(a string, list []string) bool {
160+
for _, b := range list {
161+
if b == a {
162+
return true
163+
}
164+
}
165+
return false
166+
}
167+
157168
type gceStartupScriptData struct {
158169
AutoImplode bool
159170
HardTimeoutMinutes int64
@@ -180,6 +191,53 @@ func (oe *gceOpError) Error() string {
180191
return strings.Join(errStrs, ", ")
181192
}
182193

194+
type singleGpuMapping struct {
195+
GpuCount int64
196+
GpuType string
197+
DiskSize int64
198+
}
199+
200+
var gpuMedium = singleGpuMapping{
201+
GpuCount: 1,
202+
GpuType: "nvidia-tesla-t4",
203+
DiskSize: 300,}
204+
var gpuXLarge = singleGpuMapping{
205+
GpuCount: 1,
206+
GpuType: "nvidia-tesla-v100",
207+
DiskSize: 300,}
208+
209+
func GpuMapping(vmSize string) (value singleGpuMapping) {
210+
gpuMapping := map[string] singleGpuMapping{
211+
"gpu-medium": gpuMedium,
212+
"gpu-xlarge": gpuXLarge,
213+
}
214+
return gpuMapping[vmSize]
215+
}
216+
217+
218+
func GpuDefaultGpuCount(vmSize string) (gpuCountInt int64) {
219+
return GpuMapping(vmSize).GpuCount
220+
}
221+
222+
func GpuDefaultGpuDiskSize(vmSize string) (gpuDiskSizeInt int64) {
223+
return GpuMapping(vmSize).DiskSize
224+
}
225+
226+
func GpuDefaultGpuType(vmSize string) (gpuTypeString string) {
227+
return GpuMapping(vmSize).GpuType
228+
}
229+
230+
func GPUType(varSize string) string {
231+
switch varSize {
232+
case "gpu-medium":
233+
return "gpu-medium"
234+
case "gpu-xlarge":
235+
return "gpu-xlarge"
236+
default:
237+
return ""
238+
}
239+
}
240+
183241
type gceAccountJSON struct {
184242
ClientEmail string `json:"client_email"`
185243
PrivateKey string `json:"private_key"`
@@ -827,7 +885,9 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error {
827885

828886
machineTypes := []string{p.ic.MachineType, p.ic.PremiumMachineType}
829887
for _, machineType := range gceVMSizeMapping {
830-
machineTypes = append(machineTypes, machineType);
888+
if !stringInSlice(machineType, machineTypes) {
889+
machineTypes = append(machineTypes, machineType);
890+
}
831891
}
832892
for _, zoneName := range append(zoneNames, p.alternateZones...) {
833893
for _, machineType := range machineTypes {
@@ -1421,6 +1481,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA
14211481

14221482
jobID, _ := context.JobIDFromContext(ctx)
14231483
repo, _ := context.RepositoryFromContext(ctx)
1484+
var gpuVMType = GPUType(startAttributes.VMSize)
14241485

14251486
if startAttributes.ImageName != "" {
14261487
imageName = startAttributes.ImageName
@@ -1434,6 +1495,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA
14341495
OS: startAttributes.OS,
14351496
JobID: jobID,
14361497
Repo: repo,
1498+
GpuVMType: gpuVMType,
14371499
})
14381500

14391501
if err != nil {
@@ -1485,11 +1547,31 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
14851547
Zone: c.zoneName,
14861548
}
14871549

1550+
var gpuVMType = GPUType(c.startAttributes.VMSize)
1551+
1552+
machineType := p.ic.MachineType
1553+
if c.startAttributes.VMType == "premium" {
1554+
c.startAttributes.VMSize = "premium"
1555+
machineType = p.ic.PremiumMachineType
1556+
} else if c.startAttributes.VMSize != "" {
1557+
if mtype, ok := gceVMSizeMapping[c.startAttributes.VMSize]; ok {
1558+
machineType = mtype;
1559+
//storing converted machine type for instance size identification
1560+
if gpuVMType == "" {
1561+
c.startAttributes.VMSize = machineType
1562+
}
1563+
}
1564+
}
1565+
14881566
diskSize := p.ic.DiskSize
14891567
if c.startAttributes.OS == "windows" {
14901568
diskSize = p.ic.DiskSizeWindows
14911569
}
14921570

1571+
if gpuVMType != "" {
1572+
diskSize = GpuDefaultGpuDiskSize(gpuVMType)
1573+
}
1574+
14931575
diskInitParams := &compute.AttachedDiskInitializeParams{
14941576
SourceImage: c.image.SelfLink,
14951577
DiskType: gcePdSSDForZone(c.zoneName),
@@ -1506,18 +1588,6 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15061588
},
15071589
}
15081590

1509-
machineType := p.ic.MachineType
1510-
if c.startAttributes.VMType == "premium" {
1511-
c.startAttributes.VMSize = "premium"
1512-
machineType = p.ic.PremiumMachineType
1513-
} else if c.startAttributes.VMSize != "" {
1514-
if mtype, ok := gceVMSizeMapping[c.startAttributes.VMSize]; ok {
1515-
machineType = mtype;
1516-
//storing converted machine type for instance size identification
1517-
c.startAttributes.VMSize = machineType
1518-
}
1519-
}
1520-
15211591
var ok bool
15221592
inst.MachineType, ok = p.machineTypeSelfLinks[gceMtKey(c.zoneName, machineType)]
15231593
if !ok {
@@ -1532,6 +1602,19 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15321602
p.projectID,
15331603
c.startAttributes.VMConfig.Zone,
15341604
c.startAttributes.VMConfig.GpuType)
1605+
} else if gpuVMType != "" {
1606+
logger.WithField("acceleratorConfig.AcceleratorType", acceleratorConfig.AcceleratorType).Debug("Setting AcceleratorConfig")
1607+
if !strings.HasPrefix(acceleratorConfig.AcceleratorType, "https") {
1608+
notUrlAcceleratorType := GpuDefaultGpuType(gpuVMType)
1609+
logger.WithField("notUrlAcceleratorType", notUrlAcceleratorType).Debug("Retrieving AcceleratorType from defaults")
1610+
logger.WithField("AcceleratorCount", p.ic.AcceleratorConfig.AcceleratorCount).Debug("Retrieving AcceleratorCount from defaults")
1611+
acceleratorConfig.AcceleratorCount = GpuDefaultGpuCount(gpuVMType)
1612+
acceleratorConfig.AcceleratorType = fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/zones/%s/acceleratorTypes/%s",
1613+
p.projectID,
1614+
c.zoneName,
1615+
notUrlAcceleratorType)
1616+
logger.WithField("acceleratorConfig.AcceleratorType", acceleratorConfig.AcceleratorType).Debug("Url for Accelerator Type is:")
1617+
}
15351618
}
15361619

15371620
var subnetwork string
@@ -1595,6 +1678,7 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15951678
}
15961679

15971680
inst.GuestAccelerators = []*compute.AcceleratorConfig{}
1681+
15981682
if acceleratorConfig.AcceleratorCount > 0 {
15991683
logger.Debug("GPU requested, setting acceleratorConfig")
16001684
inst.GuestAccelerators = append(inst.GuestAccelerators, acceleratorConfig)

image/api_selector.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
119119
bodyLines := []string{}
120120
lastJobID := uint64(0)
121121
lastRepo := ""
122+
gpuVMType := ""
122123

123124
for _, ts := range tags {
124125
qs := url.Values{}
@@ -127,6 +128,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
127128
qs.Set("limit", "1")
128129
qs.Set("job_id", fmt.Sprintf("%v", ts.JobID))
129130
qs.Set("repo", ts.Repo)
131+
qs.Set("gpu_vm_type", ts.GpuVMType)
130132
qs.Set("is_default", fmt.Sprintf("%v", ts.IsDefault))
131133
if len(ts.Tags) > 0 {
132134
qs.Set("tags", strings.Join(ts.Tags, ","))
@@ -135,6 +137,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
135137
bodyLines = append(bodyLines, qs.Encode())
136138
lastJobID = ts.JobID
137139
lastRepo = ts.Repo
140+
gpuVMType = ts.GpuVMType
138141
}
139142

140143
qs := url.Values{}
@@ -144,6 +147,7 @@ func (as *APISelector) queryWithTags(ctx gocontext.Context, infra string, tags [
144147
qs.Set("limit", "1")
145148
qs.Set("job_id", fmt.Sprintf("%v", lastJobID))
146149
qs.Set("repo", lastRepo)
150+
qs.Set("gpu_vm_type", gpuVMType)
147151

148152
bodyLines = append(bodyLines, qs.Encode())
149153

@@ -233,6 +237,7 @@ type tagSet struct {
233237

234238
JobID uint64
235239
Repo string
240+
GpuVMType string
236241
}
237242

238243
func (ts *tagSet) GoString() string {
@@ -244,6 +249,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
244249
Tags: []string{},
245250
JobID: params.JobID,
246251
Repo: params.Repo,
252+
GpuVMType: params.GpuVMType,
247253
}
248254
candidateTags := []*tagSet{}
249255

@@ -255,6 +261,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
255261
Tags: []string{tag},
256262
JobID: params.JobID,
257263
Repo: params.Repo,
264+
GpuVMType: params.GpuVMType,
258265
})
259266
}
260267

@@ -265,6 +272,7 @@ func (as *APISelector) buildCandidateTags(params *Params) ([]*tagSet, error) {
265272
Tags: tags,
266273
JobID: params.JobID,
267274
Repo: params.Repo,
275+
GpuVMType: params.GpuVMType,
268276
})
269277
}
270278

image/params.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ type Params struct {
1010

1111
JobID uint64
1212
Repo string
13+
GpuVMType string
1314
}

0 commit comments

Comments
 (0)