@@ -151,9 +151,20 @@ Set-LocalUser -Name travis -Password $pw
151151 "large" : "n2-standard-4" ,
152152 "x-large" : "n2-standard-8" ,
153153 "2x-large" : "n2-standard-16" ,
154+ "gpu-medium" : "n1-standard-8" ,
155+ "gpu-xlarge" : "n1-standard-8" ,
154156 }
155157)
156158
159+ func stringInSlice (a string , list []string ) bool {
160+ for _ , b := range list {
161+ if b == a {
162+ return true
163+ }
164+ }
165+ return false
166+ }
167+
157168type gceStartupScriptData struct {
158169 AutoImplode bool
159170 HardTimeoutMinutes int64
@@ -180,6 +191,53 @@ func (oe *gceOpError) Error() string {
180191 return strings .Join (errStrs , ", " )
181192}
182193
194+ type singleGpuMapping struct {
195+ GpuCount int64
196+ GpuType string
197+ DiskSize int64
198+ }
199+
200+ var gpuMedium = singleGpuMapping {
201+ GpuCount : 1 ,
202+ GpuType : "nvidia-tesla-t4" ,
203+ DiskSize : 300 ,}
204+ var gpuXLarge = singleGpuMapping {
205+ GpuCount : 1 ,
206+ GpuType : "nvidia-tesla-v100" ,
207+ DiskSize : 300 ,}
208+
209+ func GpuMapping (vmSize string ) (value singleGpuMapping ) {
210+ gpuMapping := map [string ] singleGpuMapping {
211+ "gpu-medium" : gpuMedium ,
212+ "gpu-xlarge" : gpuXLarge ,
213+ }
214+ return gpuMapping [vmSize ]
215+ }
216+
217+
218+ func GpuDefaultGpuCount (vmSize string ) (gpuCountInt int64 ) {
219+ return GpuMapping (vmSize ).GpuCount
220+ }
221+
222+ func GpuDefaultGpuDiskSize (vmSize string ) (gpuDiskSizeInt int64 ) {
223+ return GpuMapping (vmSize ).DiskSize
224+ }
225+
226+ func GpuDefaultGpuType (vmSize string ) (gpuTypeString string ) {
227+ return GpuMapping (vmSize ).GpuType
228+ }
229+
230+ func GPUType (varSize string ) string {
231+ switch varSize {
232+ case "gpu-medium" :
233+ return "gpu-medium"
234+ case "gpu-xlarge" :
235+ return "gpu-xlarge"
236+ default :
237+ return ""
238+ }
239+ }
240+
183241type gceAccountJSON struct {
184242 ClientEmail string `json:"client_email"`
185243 PrivateKey string `json:"private_key"`
@@ -827,7 +885,9 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error {
827885
828886 machineTypes := []string {p .ic .MachineType , p .ic .PremiumMachineType }
829887 for _ , machineType := range gceVMSizeMapping {
830- machineTypes = append (machineTypes , machineType );
888+ if ! stringInSlice (machineType , machineTypes ) {
889+ machineTypes = append (machineTypes , machineType );
890+ }
831891 }
832892 for _ , zoneName := range append (zoneNames , p .alternateZones ... ) {
833893 for _ , machineType := range machineTypes {
@@ -1421,6 +1481,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA
14211481
14221482 jobID , _ := context .JobIDFromContext (ctx )
14231483 repo , _ := context .RepositoryFromContext (ctx )
1484+ var gpuVMType = GPUType (startAttributes .VMSize )
14241485
14251486 if startAttributes .ImageName != "" {
14261487 imageName = startAttributes .ImageName
@@ -1434,6 +1495,7 @@ func (p *gceProvider) imageSelect(ctx gocontext.Context, startAttributes *StartA
14341495 OS : startAttributes .OS ,
14351496 JobID : jobID ,
14361497 Repo : repo ,
1498+ GpuVMType : gpuVMType ,
14371499 })
14381500
14391501 if err != nil {
@@ -1485,11 +1547,31 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
14851547 Zone : c .zoneName ,
14861548 }
14871549
1550+ var gpuVMType = GPUType (c .startAttributes .VMSize )
1551+
1552+ machineType := p .ic .MachineType
1553+ if c .startAttributes .VMType == "premium" {
1554+ c .startAttributes .VMSize = "premium"
1555+ machineType = p .ic .PremiumMachineType
1556+ } else if c .startAttributes .VMSize != "" {
1557+ if mtype , ok := gceVMSizeMapping [c .startAttributes .VMSize ]; ok {
1558+ machineType = mtype ;
1559+ //storing converted machine type for instance size identification
1560+ if gpuVMType == "" {
1561+ c .startAttributes .VMSize = machineType
1562+ }
1563+ }
1564+ }
1565+
14881566 diskSize := p .ic .DiskSize
14891567 if c .startAttributes .OS == "windows" {
14901568 diskSize = p .ic .DiskSizeWindows
14911569 }
14921570
1571+ if gpuVMType != "" {
1572+ diskSize = GpuDefaultGpuDiskSize (gpuVMType )
1573+ }
1574+
14931575 diskInitParams := & compute.AttachedDiskInitializeParams {
14941576 SourceImage : c .image .SelfLink ,
14951577 DiskType : gcePdSSDForZone (c .zoneName ),
@@ -1506,18 +1588,6 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15061588 },
15071589 }
15081590
1509- machineType := p .ic .MachineType
1510- if c .startAttributes .VMType == "premium" {
1511- c .startAttributes .VMSize = "premium"
1512- machineType = p .ic .PremiumMachineType
1513- } else if c .startAttributes .VMSize != "" {
1514- if mtype , ok := gceVMSizeMapping [c .startAttributes .VMSize ]; ok {
1515- machineType = mtype ;
1516- //storing converted machine type for instance size identification
1517- c .startAttributes .VMSize = machineType
1518- }
1519- }
1520-
15211591 var ok bool
15221592 inst .MachineType , ok = p .machineTypeSelfLinks [gceMtKey (c .zoneName , machineType )]
15231593 if ! ok {
@@ -1532,6 +1602,19 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15321602 p .projectID ,
15331603 c .startAttributes .VMConfig .Zone ,
15341604 c .startAttributes .VMConfig .GpuType )
1605+ } else if gpuVMType != "" {
1606+ logger .WithField ("acceleratorConfig.AcceleratorType" , acceleratorConfig .AcceleratorType ).Debug ("Setting AcceleratorConfig" )
1607+ if ! strings .HasPrefix (acceleratorConfig .AcceleratorType , "https" ) {
1608+ notUrlAcceleratorType := GpuDefaultGpuType (gpuVMType )
1609+ logger .WithField ("notUrlAcceleratorType" , notUrlAcceleratorType ).Debug ("Retrieving AcceleratorType from defaults" )
1610+ logger .WithField ("AcceleratorCount" , p .ic .AcceleratorConfig .AcceleratorCount ).Debug ("Retrieving AcceleratorCount from defaults" )
1611+ acceleratorConfig .AcceleratorCount = GpuDefaultGpuCount (gpuVMType )
1612+ acceleratorConfig .AcceleratorType = fmt .Sprintf ("https://www.googleapis.com/compute/v1/projects/%s/zones/%s/acceleratorTypes/%s" ,
1613+ p .projectID ,
1614+ c .zoneName ,
1615+ notUrlAcceleratorType )
1616+ logger .WithField ("acceleratorConfig.AcceleratorType" , acceleratorConfig .AcceleratorType ).Debug ("Url for Accelerator Type is:" )
1617+ }
15351618 }
15361619
15371620 var subnetwork string
@@ -1595,6 +1678,7 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
15951678 }
15961679
15971680 inst .GuestAccelerators = []* compute.AcceleratorConfig {}
1681+
15981682 if acceleratorConfig .AcceleratorCount > 0 {
15991683 logger .Debug ("GPU requested, setting acceleratorConfig" )
16001684 inst .GuestAccelerators = append (inst .GuestAccelerators , acceleratorConfig )
0 commit comments