// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import 'package:code_builder/code_builder.dart';
import 'package:smithy/ast.dart';
import 'package:smithy_codegen/smithy_codegen.dart';
import 'package:smithy_codegen/src/generator/generation_context.dart';
import 'package:smithy_codegen/src/generator/generator.dart';
import 'package:smithy_codegen/src/generator/serialization/protocol_traits.dart';
import 'package:smithy_codegen/src/generator/types.dart';
import 'package:smithy_codegen/src/util/protocol_ext.dart';
import 'package:smithy_codegen/src/util/shape_ext.dart';
import 'package:smithy_codegen/src/util/symbol_ext.dart';

/// Generates the library containing the HTTP operation class for an operation
/// shape.
///
/// The generated class extends either the plain or the paginated HTTP
/// operation base type and overrides the request-building, output-building
/// and (for AWS services) endpoint/retry members.
// NOTE(review): generic type arguments elsewhere in this file had been lost;
// confirm whether `LibraryGenerator` should carry a type argument here.
class OperationGenerator extends LibraryGenerator
    with OperationGenerationContext {
  OperationGenerator(
    super.shape,
    CodegenContext context, {
    super.smithyLibrary,
  }) : super(context: context);

  /// The name of the generated operation class.
  @override
  String get className => shape.dartName(context);

  /// Builds the library for this operation.
  ///
  /// The operation class is only added when the operation has an HTTP trait;
  /// otherwise the library is built without it.
  @override
  Library generate() {
    if (httpTrait != null) {
      builder.body.add(_operationClass);
    }
    return builder.build();
  }

  /// The operation's implementation class.
  Class get _operationClass {
    // Copy to a final local so the null check promotes it inside the closure.
    final paginatedTraits = this.paginatedTraits;
    // `_httpOverrides` yields both fields and methods; evaluate it once and
    // split the result by kind below.
    final httpOverrides = _httpOverrides.toList();
    return Class(
      (c) => c
        ..docs.addAll([
          if (shape.hasDocs(context)) shape.formattedDocs(context),
        ])
        ..name = className
        ..extend = paginatedTraits == null
            ? DartTypes.smithy.httpOperation(
                inputPayload.symbol.unboxed,
                inputSymbol,
                outputPayload.symbol.unboxed,
                outputSymbol,
              )
            : DartTypes.smithy.paginatedHttpOperation(
                inputPayload.symbol.unboxed,
                inputSymbol,
                outputPayload.symbol.unboxed,
                outputSymbol,
                paginatedTraits.inputToken?.symbol.unboxed ??
                    DartTypes.core.void$,
                paginatedTraits.pageSize?.symbol.unboxed ??
                    DartTypes.core.void$,
                paginatedTraits.items?.symbol.unboxed ?? outputSymbol,
              )
        ..constructors.add(_constructor)
        ..fields.addAll([
          _protocolsGetter,
          // Only the `Field` specs belong here; without the type argument
          // every override (methods included) would be added as a field.
          ...httpOverrides.whereType<Field>(),
          ...shape.protocolFields(context),
        ])
        ..methods.addAll([
          ...httpOverrides.whereType<Method>(),
          ..._paginatedMethods,
        ]),
    );
  }

  /// The operation class's constructor.
  ///
  /// Every constructor parameter is optional. Parameters which are not
  /// `this.`-initializing are assigned to a private field of the same name
  /// prefixed with `_`.
  Constructor get _constructor {
    final parameters = shape.constructorParameters(context);
    return Constructor(
      (ctor) => ctor
        ..docs.addAll([
          if (shape.hasDocs(context)) shape.formattedDocs(context),
        ])
        ..optionalParameters.addAll(parameters)
        ..initializers.addAll(
          parameters.where((p) => !p.toThis).map(
                (field) =>
                    refer('_${field.name}').assign(refer(field.name)).code,
              ),
        ),
    );
  }

  /// The statements of the HTTP request builder.
  ///
  /// The statements operate on a request builder named `b` and an operation
  /// input named `input`. They set, in order: the method, the path, the host
  /// prefix, headers, prefixed headers, query parameters, the query parameter
  /// map and finally any checksum request interceptor.
  Iterable<Code> get _httpRequestBuilder sync* {
    final builder = refer('b');
    final input = refer('input');
    final httpTrait = this.httpTrait!;

    yield builder
        .property('method')
        .assign(literalString(httpTrait.method))
        .statement;

    final uri = httpTrait.uri;

    // S3 requires special treatment of the path
    // https://awslabs.github.io/smithy/1.0/spec/aws/customizations/s3-customizations.html
    final isS3 = context.service?.resolvedService?.sdkId == 'S3';
    final hasBucketInPath = isS3 && uri.contains(RegExp('^/{Bucket}'));
    if (hasBucketInPath) {
      yield builder
          .property('path')
          .assign(
            refer('_s3ClientConfig').property('usePathStyle').conditional(
                  // `raw` because some AWS paths use the '$' char.
                  literalString(uri, raw: true),
                  // Change `/{Bucket}/{Key+}?param` to `/{Key+}?param` and
                  // `/{Bucket}` to `/`
                  literalString(
                    uri.replaceFirst(RegExp('^/{Bucket}/?'), '/'),
                    raw: true,
                  ),
                ),
          )
          .statement;
    } else {
      yield builder
          .property('path')
          // `raw` because some AWS paths use the '$' char.
          .assign(literalString(uri, raw: true))
          .statement;
    }

    // NOTE(review): the type argument was missing; `EndpointTrait` is
    // reconstructed from the `hostPrefix` access. Confirm the class name.
    final hostPrefix = shape.getTrait<EndpointTrait>()?.hostPrefix;
    if (hasBucketInPath) {
      // When not using path style, the bucket moves from the path into the
      // host prefix.
      yield builder
          .property('hostPrefix')
          .assign(
            refer('_s3ClientConfig').property('usePathStyle').conditional(
                  literal(hostPrefix),
                  literalString('{Bucket}.${hostPrefix ?? ''}'),
                ),
          )
          .statement;
    } else if (hostPrefix != null) {
      yield builder
          .property('hostPrefix')
          .assign(literalString(hostPrefix))
          .statement;
    }

    for (final entry in httpInputTraits.httpHeaders.entries) {
      yield _inputHttpHeader(
        literalString(entry.key),
        entry.value,
        input.property(entry.value.dartName(ShapeType.structure)),
        isNullable: entry.value.isNullable(context, inputShape),
      );
    }

    final prefixHeaders = httpInputTraits.httpPrefixHeaders;
    if (prefixHeaders != null) {
      yield _httpPrefixedHeaders(prefixHeaders);
    }

    for (final entry in httpInputTraits.httpQuery.entries) {
      yield _httpQuery(
        literalString(entry.key),
        entry.value,
        input.property(entry.value.dartName(ShapeType.structure)),
        isNullable: entry.value.isNullable(context, inputShape),
      );
    }

    final queryParams = httpInputTraits.httpQueryParams;
    if (queryParams != null) {
      yield _httpQueryParameters(queryParams);
    }

    // HTTP checksums (dependent on input)
    // NOTE(review): the type argument was missing; `HttpChecksumTrait` is
    // reconstructed from the `requestChecksumRequired` and
    // `requestAlgorithmMember` accesses. Confirm the class name.
    final checksumTrait = shape.getTrait<HttpChecksumTrait>();
    if (checksumTrait == null) {
      return;
    }
    final checksumRequired = checksumTrait.requestChecksumRequired ?? false;
    final requestMemberName = checksumTrait.requestAlgorithmMember;
    final addInterceptor =
        builder.property('requestInterceptors').property('add');

    if (requestMemberName == null) {
      // No algorithm member on the input: only add the interceptor when a
      // checksum is required.
      if (checksumRequired) {
        yield addInterceptor.call([
          DartTypes.smithyAws.withChecksum.constInstance([]),
        ]).statement;
      }
      return;
    }

    final requestMember = inputShape.members[requestMemberName]!;
    final memberIsNullable = requestMember.isNullable(context, inputShape);
    final inputProperty = input.property(
      requestMember.dartName(ShapeType.structure),
    );
    if (checksumRequired) {
      // Always add the interceptor, passing the member's `value` through.
      yield addInterceptor.call([
        DartTypes.smithyAws.withChecksum.newInstance([
          inputProperty.nullableProperty('value', memberIsNullable),
        ]),
      ]).statement;
    } else {
      // Only add the interceptor when the caller supplied the member.
      yield addInterceptor.call([
        DartTypes.smithyAws.withChecksum.newInstance([
          (memberIsNullable ? inputProperty.nullChecked : inputProperty)
              .property('value'),
        ]),
      ]).wrapWithBlockIf(
        inputProperty.notEqualTo(literalNull),
        memberIsNullable,
      );
    }
  }

  /// Adds the header to the request's headers map.
  ///
  /// [key] is the header name expression, [value] the shape being serialized
  /// and [valueRef] the expression which reads it. When [isNullable] is true,
  /// the assignment is guarded by a null check on [valueRef].
  Code _inputHttpHeader(
    Expression key,
    Shape value,
    Expression valueRef, {
    required bool isNullable,
  }) {
    final builder = refer('b');
    final nonNullValue = isNullable ? valueRef.nullChecked : valueRef;
    final toStringExp = valueToString(
      nonNullValue,
      value,
      isHeader: true,
    );
    final addHeader =
        builder.property('headers').index(key).assign(toStringExp).statement;
    final targetShape =
        value is MemberShape ? context.shapeFor(value.target) : value;
    // NOTE(review): the type argument of `hasTrait` was missing;
    // `MediaTypeTrait` is the presumed original. Confirm before merging.
    final needsNotEmptyCheck =
        const [ShapeType.list, ShapeType.set, ShapeType.string]
                .contains(targetShape.getType()) &&
            !targetShape.hasTrait<MediaTypeTrait>() &&
            !targetShape.isEnum;
    return addHeader
        // From restJson1 test suite:
        // "Do not send null values, empty strings, or empty lists over the wire in headers"
        .wrapWithBlockIf(
          nonNullValue.property('isNotEmpty'),
          needsNotEmptyCheck,
        )
        .wrapWithBlockIf(
          valueRef.notEqualTo(literalNull),
          isNullable,
        );
  }

  /// Adds the prefixed headers to the request's headers map.
  ///
  /// Emits a loop over the entries of the input's map member which writes one
  /// header per entry, named by the trait's prefix followed by the entry key.
  Code _httpPrefixedHeaders(HttpPrefixHeaders headers) {
    final mapShape = context.shapeFor(headers.member.target) as MapShape;
    final mapRef =
        refer('input').property(headers.member.dartName(ShapeType.structure));
    final isNullableMap = headers.member.isNullable(context, inputShape);
    final valueTarget = context.shapeFor(mapShape.value.target);
    final entries = (isNullableMap ? mapRef.nullChecked : mapRef)
        .property('toMap')
        .call([])
        .property('entries');
    return Block.of([
      const Code('for (var entry in '),
      entries.code,
      const Code(') {'),
      _inputHttpHeader(
        // The escaped `\$` keeps `${entry.key}` as an interpolation in the
        // generated code rather than evaluating it here.
        literalString('${headers.trait.value}\${entry.key}'),
        valueTarget,
        refer('entry').property('value'),
        isNullable: valueTarget.isNullable(context, mapShape),
      ),
      const Code('}'),
    ]).wrapWithBlockIf(mapRef.notEqualTo(literalNull), isNullableMap);
  }

  /// Adds the shape to the request's query parameters.
  ///
  /// Collection shapes are expanded into a loop which adds one query
  /// parameter per element under the same [key]. When [isNullable] is true,
  /// the generated code is guarded by a null check on [valueRef].
  Code _httpQuery(
    Expression key,
    Shape value,
    Expression valueRef, {
    required bool isNullable,
  }) {
    final builder = refer('b');
    final targetShape =
        value is MemberShape ? context.shapeFor(value.target) : value;
    final nonNullValue = isNullable ? valueRef.nullChecked : valueRef;

    if (targetShape is CollectionShape) {
      final isNullableMember =
          targetShape.member.isNullable(context, targetShape);
      Expression memberRef = refer('value');
      if (isNullableMember) {
        memberRef = memberRef.nullChecked;
      }
      return Block.of([
        const Code('for (var value in '),
        nonNullValue.code,
        const Code(') {'),
        _httpQuery(
          key,
          targetShape.member,
          memberRef,
          isNullable: isNullableMember,
        ),
        const Code('}'),
      ]).wrapWithBlockIf(valueRef.notEqualTo(literalNull), isNullable);
    }

    final toStringExp = valueToString(
      nonNullValue,
      targetShape,
      isHeader: false,
    );
    final addParam = builder
        .property('queryParameters')
        .property('add')
        .call([key, toStringExp]).statement;
    return addParam.wrapWithBlockIf(
      valueRef.notEqualTo(literalNull),
      isNullable,
    );
  }

  /// Adds all query parameters in a map to the request's query parameters.
  ///
  /// Emits a loop over the entries of the input's map member and delegates
  /// each entry's value to [_httpQuery], keyed by the entry's key.
  Code _httpQueryParameters(MemberShape queryParameters) {
    final targetShape = context.shapeFor(queryParameters.target) as MapShape;
    final valueShape = context.shapeFor(targetShape.value.target);
    final isNullable = queryParameters.isNullable(context, inputShape);
    final mapRef =
        refer('input').property(queryParameters.dartName(ShapeType.structure));
    final entriesRef = (isNullable ? mapRef.nullChecked : mapRef)
        .property('toMap')
        .call([])
        .property('entries');
    final valueRef = refer('entry').property('value');
    return Block.of([
      const Code('for (var entry in '),
      entriesRef.code,
      const Code(') {'),
      _httpQuery(
        refer('entry').property('key'),
        valueShape,
        valueRef,
        isNullable: valueShape.isNullable(context, targetShape),
      ),
      const Code('}'),
    ]).wrapWithBlockIf(mapRef.notEqualTo(literalNull), isNullable);
  }

  /// The required HTTP operation overrides.
  ///
  /// Yields `Method` specs and, for AWS services, one `Field` spec
  /// (`_awsEndpoint`); callers split the result by kind.
  Iterable<Spec> get _httpOverrides sync* {
    // The `buildRequest` method
    final request = DartTypes.smithy.httpRequest.newInstance([
      Method(
        (m) => m
          ..requiredParameters.add(Parameter((p) => p..name = 'b'))
          ..lambda = false
          ..body = Block.of(_httpRequestBuilder),
      ).closure,
    ]);
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = DartTypes.smithy.httpRequest
        ..name = 'buildRequest'
        ..requiredParameters.add(
          Parameter(
            (p) => p
              ..name = 'input'
              ..type = inputSymbol,
          ),
        )
        ..body = request.code,
    );

    // The `successCode` method
    var successCode = literalNum(httpTrait!.code);
    final responseCodeMember = httpOutputTraits.httpResponseCode;
    if (responseCodeMember != null) {
      // Prefer the output's response code member, falling back to the
      // code declared on the HTTP trait.
      successCode = refer('output')
          .nullSafeProperty(responseCodeMember.dartName(ShapeType.structure))
          .ifNullThen(successCode);
    }
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = DartTypes.core.int
        ..name = 'successCode'
        ..lambda = true
        ..optionalParameters.add(
          Parameter(
            (p) => p
              ..name = 'output'
              ..type = outputSymbol.boxed,
          ),
        )
        ..body = successCode.code,
    );

    // The `buildOutput` method
    final output = outputPayload.symbol == DartTypes.smithy.unit
        ? refer('payload')
        : outputSymbol.newInstanceNamed('fromResponse', [
            refer('payload'),
            refer('response'),
          ]);
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = outputSymbol
        ..name = 'buildOutput'
        ..requiredParameters.addAll([
          Parameter(
            (p) => p
              ..name = 'payload'
              ..type = outputPayload.symbol,
          ),
          Parameter(
            (p) => p
              ..name = 'response'
              ..type = DartTypes.awsCommon.awsBaseHttpResponse,
          ),
        ])
        ..body = output.code,
    );

    // The `errorTypes` getter
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = DartTypes.core.list(DartTypes.smithy.smithyError)
        ..type = MethodType.getter
        ..name = 'errorTypes'
        ..lambda = true
        ..body = literalConstList([
          for (final error in errorSymbols) error.constInstance,
        ]).code,
    );

    // The `runtimeTypeName` getter
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = DartTypes.core.string
        ..type = MethodType.getter
        ..name = 'runtimeTypeName'
        ..lambda = true
        ..body = literalString(shape.shapeId.shape).code,
    );

    // Everything below only applies to AWS services.
    final resolvedService = context.service?.resolvedService;
    if (resolvedService != null) {
      // The standard AWS retryer
      yield Method(
        (m) => m
          ..annotations.add(DartTypes.core.override)
          ..returns = DartTypes.smithyAws.awsRetryer
          ..name = 'retryer'
          ..type = MethodType.getter
          ..body = DartTypes.smithyAws.awsRetryer.newInstance([]).code,
      );

      // S3 requires a special baseUri to consider customizations
      if (resolvedService.sdkId == 'S3') {
        yield Method(
          (m) => m
            ..annotations.add(DartTypes.core.override)
            ..returns = DartTypes.core.uri
            ..name = 'baseUri'
            ..type = MethodType.getter
            ..body = Code.scope(
              (allocate) => '''
var baseUri = _baseUri ?? endpoint.uri;
if (_s3ClientConfig.useDualStack) {
  baseUri = baseUri.replace(
    host: baseUri.host.replaceFirst(${allocate(DartTypes.core.regExp)}(r'^s3\\.'), 's3.dualstack.'),
  );
}
if (_s3ClientConfig.useAcceleration) {
  baseUri = baseUri.replace(
    host: baseUri.host
        .replaceFirst(${allocate(DartTypes.core.regExp)}('\$_region\\\\.'), '')
        .replaceFirst(${allocate(DartTypes.core.regExp)}(r'^s3\\.'), 's3-accelerate.'),
  );
}
return baseUri;''',
            ),
        );
      } else {
        yield Method(
          (m) => m
            ..annotations.add(DartTypes.core.override)
            ..returns = DartTypes.core.uri
            ..name = 'baseUri'
            ..type = MethodType.getter
            ..body = refer('_baseUri')
                .ifNullThen(refer('endpoint').property('uri'))
                .code,
        );
      }

      // Create an awsEndpoint field we can reuse
      final endpointResolverLib = context.endpointResolverLibrary.libraryUrl;
      final endpointResolver = refer('endpointResolver', endpointResolverLib);
      final sdkId = refer('sdkId', endpointResolverLib);
      yield Field(
        (m) => m
          ..late = true
          ..modifier = FieldModifier.final$
          ..type = DartTypes.smithyAws.awsEndpoint
          ..name = '_awsEndpoint'
          ..assignment = endpointResolver.property('resolve').call([
            sdkId,
            refer('_region'),
          ]).code,
      );
      yield Method(
        (m) => m
          ..annotations.add(DartTypes.core.override)
          ..returns = DartTypes.smithy.endpoint
          ..name = 'endpoint'
          ..type = MethodType.getter
          ..body = refer('_awsEndpoint').property('endpoint').code,
      );

      // Override `run` to use zone values
      yield Method(
        (m) => m
          ..annotations.add(DartTypes.core.override)
          ..returns = DartTypes.smithy.smithyOperation(outputSymbol)
          ..name = 'run'
          ..requiredParameters.add(
            Parameter(
              (p) => p
                ..type = inputSymbol
                ..name = 'input',
            ),
          )
          ..optionalParameters.addAll([
            Parameter(
              (p) => p
                ..type = DartTypes.awsCommon.awsHttpClient.boxed
                ..name = 'client'
                ..named = true,
            ),
            Parameter(
              (p) => p
                ..type = DartTypes.smithy.shapeId.boxed
                ..name = 'useProtocol'
                ..named = true,
            ),
          ])
          ..body = DartTypes.async.runZoned
              .call([
                Method(
                  (m) => m
                    ..body = refer('super').property('run').call([
                      refer('input'),
                    ], {
                      'client': refer('client'),
                      'useProtocol': refer('useProtocol'),
                    }).code,
                ).closure,
              ], {
                'zoneValues': literalMap({
                  literalNullSafeSpread(): refer('_awsEndpoint')
                      .property('credentialScope')
                      .nullSafeProperty('zoneValues'),
                  literalSpread(): literalMap({
                    DartTypes.awsCommon.awsHeaders.property('sdkInvocationId'):
                        DartTypes.awsCommon.uuid(),
                  }),
                }),
              })
              .returned
              .statement,
      );
    }
  }

  /// The `protocols` override getter.
  ///
  /// Emitted as a `late final` field whose list holds one instance per
  /// service protocol. Each protocol's own interceptors are concatenated with
  /// the `_requestInterceptors` / `_responseInterceptors` fields.
  Field get _protocolsGetter => Field(
        (f) => f
          ..annotations.add(DartTypes.core.override)
          ..late = true
          ..modifier = FieldModifier.final$
          ..type = DartTypes.core.list(
            DartTypes.smithy.httpProtocol(
              inputPayload.symbol.unboxed,
              inputSymbol,
              outputPayload.symbol.unboxed,
              outputSymbol,
            ),
          )
          ..name = 'protocols'
          ..assignment = literalList([
            for (final protocol in context.serviceProtocols)
              protocol.instantiableProtocolSymbol.newInstance([], {
                'serializers': protocol.serializers(context),
                'builderFactories': context.builderFactoriesRef,
                'requestInterceptors': literalList(
                  protocol.requestInterceptors(shape, context),
                  DartTypes.smithy.httpRequestInterceptor,
                ).operatorAdd(refer('_requestInterceptors')),
                'responseInterceptors': literalList(
                  protocol.responseInterceptors(shape, context),
                  DartTypes.smithy.httpResponseInterceptor,
                ).operatorAdd(refer('_responseInterceptors')),
                ...protocol.extraParameters(shape, context),
              }),
          ]).code,
      );

  /// The pagination overrides: `getToken`, `getItems` and `rebuildInput`.
  ///
  /// Yields nothing when the operation has no paginated traits.
  Iterable<Method> get _paginatedMethods sync* {
    final paginatedTraits = this.paginatedTraits;
    if (paginatedTraits == null) {
      return;
    }

    const emptyBody = Code('');

    // The `getToken` method.
    final outputToken = paginatedTraits.outputToken;
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = outputToken?.symbol.boxed ?? DartTypes.core.void$
        ..name = 'getToken'
        ..requiredParameters.add(
          Parameter(
            (p) => p
              ..type = outputSymbol
              ..name = 'output',
          ),
        )
        // Without an output token the method has an empty block body.
        ..lambda = outputToken != null
        ..body = outputToken?.buildExpression.call(refer('output')).code ??
            emptyBody,
    );

    // The `getItems` method.
    Expression? defaultValue;
    final items = paginatedTraits.items;
    if (items != null && items.isNullable) {
      // Nullable items fall back to a new instance of the items type,
      // referenced without nullability or type arguments.
      final symbol = items.symbol.typeRef.rebuild(
        (t) => t
          ..isNullable = false
          ..types.clear(),
      );
      defaultValue = symbol.newInstance([]);
    }
    var itemsBody =
        items?.buildExpression.call(refer('output')) ?? refer('output');
    if (defaultValue != null) {
      itemsBody = itemsBody.ifNullThen(defaultValue);
    }
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = items?.symbol.unboxed ?? outputSymbol
        ..name = 'getItems'
        ..requiredParameters.add(
          Parameter(
            (p) => p
              ..type = outputSymbol
              ..name = 'output',
          ),
        )
        ..lambda = true
        ..body = itemsBody.code,
    );

    // The `rebuildInput` method.
    final inputToken = paginatedTraits.inputToken;
    final pageSize = paginatedTraits.pageSize;
    var inputTokenBuilder = inputToken?.buildExpression(refer('b'));
    if (inputTokenBuilder != null) {
      final inputTokenTarget = context.shapeFor(inputToken!.member.target);
      // Targets of these types are updated via `replace` instead of plain
      // assignment.
      const builderTypes = [
        ShapeType.set,
        ShapeType.list,
        ShapeType.map,
        ShapeType.structure,
      ];
      final needsBuilder = builderTypes.contains(inputTokenTarget.getType());
      final token = refer('token');
      if (needsBuilder) {
        inputTokenBuilder = inputTokenBuilder.property('replace').call([token]);
      } else {
        inputTokenBuilder = inputTokenBuilder.assign(token);
      }
    }
    yield Method(
      (m) => m
        ..annotations.add(DartTypes.core.override)
        ..returns = inputSymbol
        ..name = 'rebuildInput'
        ..requiredParameters.addAll([
          Parameter(
            (p) => p
              ..type = inputSymbol
              ..name = 'input',
          ),
          Parameter(
            (p) => p
              ..type = inputToken?.symbol.unboxed ?? DartTypes.core.void$
              ..name = 'token',
          ),
          Parameter(
            (p) => p
              ..type = pageSize?.symbol.boxed ?? DartTypes.core.void$
              ..name = 'pageSize',
          ),
        ])
        ..lambda = true
        ..body = refer('input').property('rebuild').call([
          Method(
            (m) => m
              ..requiredParameters.add(Parameter((p) => p..name = 'b'))
              ..body = Block.of([
                if (inputToken != null && outputToken != null)
                  inputTokenBuilder!.statement,
                if (pageSize != null)
                  pageSize
                      .buildExpression(refer('b'))
                      .assign(refer('pageSize'))
                      .statement
                      .wrapWithBlockIf(
                        refer('pageSize').notEqualTo(literalNull),
                        pageSize.isNullable,
                      ),
              ]),
          ).closure,
        ]).code,
    );
  }
}