// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package com.google.crypto.tink.jwt;

import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeFalse;

import com.google.crypto.tink.CleartextKeysetHandle;
import com.google.crypto.tink.KeyTemplate;
import com.google.crypto.tink.KeyTypeManager;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.proto.JwtEcdsaAlgorithm;
import com.google.crypto.tink.proto.JwtEcdsaKeyFormat;
import com.google.crypto.tink.proto.JwtEcdsaPrivateKey;
import com.google.crypto.tink.proto.JwtEcdsaPublicKey;
import com.google.crypto.tink.proto.KeyData.KeyMaterialType;
import com.google.crypto.tink.proto.Keyset;
import com.google.crypto.tink.subtle.Base64;
import com.google.crypto.tink.subtle.EcdsaSignJce;
import com.google.crypto.tink.subtle.EllipticCurves;
import com.google.crypto.tink.subtle.EllipticCurves.EcdsaEncoding;
import com.google.crypto.tink.subtle.Enums;
import com.google.crypto.tink.subtle.Random;
import com.google.crypto.tink.testing.TestUtil;
import com.google.gson.JsonObject;
import com.google.protobuf.ByteString;
import com.google.protobuf.ExtensionRegistryLite;
import java.io.ByteArrayInputStream;
import java.security.GeneralSecurityException;
import java.security.interfaces.ECPrivateKey;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;

/**
 * Unit tests for {@link JwtEcdsaSignKeyManager}.
 *
 * <p>Covers key-format validation, key creation, public-key extraction, key templates, and
 * end-to-end signing/verification, including tampered tokens and hand-crafted headers.
 */
@RunWith(JUnitParamsRunner.class)
public class JwtEcdsaSignKeyManagerTest {

  @BeforeClass
  public static void setUp() throws Exception {
    // Registers the JWT signature configuration used by the KeysetHandle-based tests below.
    JwtSignatureConfig.register();
  }

  private final JwtEcdsaSignKeyManager manager = new JwtEcdsaSignKeyManager();
  private final KeyTypeManager.KeyFactory<JwtEcdsaKeyFormat, JwtEcdsaPrivateKey> factory =
      manager.keyFactory();

  /** Returns a key format that only sets the given ECDSA algorithm. */
  private static JwtEcdsaKeyFormat createKeyFormat(JwtEcdsaAlgorithm algorithm) {
    return JwtEcdsaKeyFormat.newBuilder().setAlgorithm(algorithm).build();
  }

  /** Parameter source: the ES256, ES384 and ES512 algorithms. */
  private static Object[] parametersAlgos() {
    return new Object[] {JwtEcdsaAlgorithm.ES256, JwtEcdsaAlgorithm.ES384, JwtEcdsaAlgorithm.ES512};
  }

  /** Parameter source: the ES256, ES384 and ES512 key templates of the manager. */
  private static Object[] templates() {
    return new Object[] {
      JwtEcdsaSignKeyManager.jwtES256Template(),
      JwtEcdsaSignKeyManager.jwtES384Template(),
      JwtEcdsaSignKeyManager.jwtES512Template()
    };
  }

  @Test
  public void basics() throws Exception {
    assertThat(manager.getKeyType())
        .isEqualTo("type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey");
    assertThat(manager.getVersion()).isEqualTo(0);
    assertThat(manager.keyMaterialType()).isEqualTo(KeyMaterialType.ASYMMETRIC_PRIVATE);
  }

  @Test
  public void validateKeyFormat_empty_throw() throws Exception {
    assertThrows(
        GeneralSecurityException.class,
        () -> factory.validateKeyFormat(JwtEcdsaKeyFormat.getDefaultInstance()));
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void validateKeyFormat_ok(JwtEcdsaAlgorithm algorithm) throws GeneralSecurityException {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    factory.validateKeyFormat(format);
  }

  /** Asserts that the key's public part carries the algorithm requested by {@code keyFormat}. */
  private static void checkConsistency(JwtEcdsaPrivateKey privateKey, JwtEcdsaKeyFormat keyFormat) {
    assertThat(privateKey.getPublicKey().getAlgorithm()).isEqualTo(keyFormat.getAlgorithm());
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createKeys_ok(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    JwtEcdsaPrivateKey key = factory.createKey(format);
    checkConsistency(key, format);
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createKey_alwaysNewElement_ok(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    Set<String> keys = new TreeSet<>();
    // Calls createKey multiple times and makes sure that every call yields a distinct private
    // key value (a duplicate would collapse in the set and shrink its size).
    int numTests = 5;
    for (int i = 0; i < numTests; i++) {
      JwtEcdsaPrivateKey key = factory.createKey(format);
      keys.add(TestUtil.hexEncode(key.getKeyValue().toByteArray()));
    }
    assertThat(keys).hasSize(numTests);
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void getPublicKey_checkValues(JwtEcdsaAlgorithm algorithm) throws Exception {
    JwtEcdsaPrivateKey privateKey = factory.createKey(createKeyFormat(algorithm));
    JwtEcdsaPublicKey publicKey = manager.getPublicKey(privateKey);

    assertThat(publicKey).isEqualTo(privateKey.getPublicKey());
  }

  @Test
  @Parameters(method = "parametersAlgos")
  public void createCorruptedPublicKeyPrimitive_throws(JwtEcdsaAlgorithm algorithm)
      throws Exception {
    JwtEcdsaKeyFormat format = createKeyFormat(algorithm);
    JwtEcdsaPrivateKey originalKey = factory.createKey(format);
    byte[] originalPubX = originalKey.getPublicKey().getX().toByteArray();
    byte[] originalPubY = originalKey.getPublicKey().getY().toByteArray();
    // Flip one bit of the X coordinate so the public key no longer matches the private key.
    originalPubX[0] = (byte) (originalPubX[0] ^ 0x01);
    ByteString corruptedPubX = ByteString.copyFrom(originalPubX);
    JwtEcdsaPublicKey corruptedPub =
        JwtEcdsaPublicKey.newBuilder()
            .setVersion(originalKey.getPublicKey().getVersion())
            .setAlgorithm(algorithm)
            .setX(corruptedPubX)
            .setY(ByteString.copyFrom(originalPubY))
            .build();
    JwtEcdsaPrivateKey corruptedKey =
        JwtEcdsaPrivateKey.newBuilder()
            .setVersion(originalKey.getVersion())
            .setPublicKey(corruptedPub)
            .setKeyValue(originalKey.getKeyValue())
            .build();
    assertThrows(
        GeneralSecurityException.class,
        () -> manager.getPrimitive(corruptedKey, JwtPublicKeySign.class));
  }

  @Test
  public void testDeriveKey_throw() throws Exception {
    assertThrows(
        UnsupportedOperationException.class,
        () ->
            factory.deriveKey(
                JwtEcdsaKeyFormat.getDefaultInstance(),
                new ByteArrayInputStream(Random.randBytes(100))));
  }

  /**
   * Asserts that {@code template} targets this key manager, uses the RAW output prefix, and
   * encodes a key format with the given algorithm.
   */
  private static void checkTemplate(KeyTemplate template, JwtEcdsaAlgorithm algorithm)
      throws Exception {
    assertThat(template.getTypeUrl()).isEqualTo(new JwtEcdsaSignKeyManager().getKeyType());
    assertThat(template.getOutputPrefixType()).isEqualTo(KeyTemplate.OutputPrefixType.RAW);
    JwtEcdsaKeyFormat format =
        JwtEcdsaKeyFormat.parseFrom(template.getValue(), ExtensionRegistryLite.getEmptyRegistry());
    assertThat(format.getAlgorithm()).isEqualTo(algorithm);
  }

  @Test
  public void testJwtES256Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES256Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES256);
  }

  @Test
  public void testJwtES384Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES384Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES384);
  }

  @Test
  public void testJwtES512Template_ok() throws Exception {
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES512Template();
    checkTemplate(template, JwtEcdsaAlgorithm.ES512);
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_success(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    JwtValidator validator = new JwtValidator.Builder().build();

    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);
    VerifiedJwt verifiedToken = verifier.verifyAndDecode(signedCompact, validator);
    assertThat(verifiedToken.getIssuer()).isEqualTo("issuer");
    assertThat(verifiedToken.hasTypeHeader()).isFalse();

    RawJwt rawTokenWithType = new RawJwt.Builder().setTypeHeader("typeHeader").build();
    String signedCompactWithType = signer.signAndEncode(rawTokenWithType);
    VerifiedJwt verifiedTokenWithType = verifier.verifyAndDecode(signedCompactWithType, validator);
    assertThat(verifiedTokenWithType.getTypeHeader()).isEqualTo("typeHeader");
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerifyDifferentKey_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // A freshly generated, unrelated keyset must reject the token.
    KeysetHandle otherHandle = KeysetHandle.generateNew(template);
    JwtPublicKeyVerify otherVerifier =
        otherHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class,
        () -> otherVerifier.verifyAndDecode(signedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_header_modification_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // Modify the decoded header by appending a space, then re-encode it. The JSON stays
    // well-formed, but the signed bytes change, so verification must fail.
    String[] parts = signedCompact.split("\\.", -1);
    String header = new String(Base64.urlSafeDecode(parts[0]), UTF_8);
    String headerBase64 = Base64.urlSafeEncode((header + " ").getBytes(UTF_8));
    String modifiedCompact = headerBase64 + "." + parts[1] + "." + parts[2];

    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class, () -> verifier.verifyAndDecode(modifiedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_payload_modification_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String signedCompact = signer.signAndEncode(rawToken);

    // Modify the decoded payload by appending a space, then re-encode it. The JSON stays
    // well-formed, but the signed bytes change, so verification must fail.
    String[] parts = signedCompact.split("\\.", -1);
    String payload = new String(Base64.urlSafeDecode(parts[1]), UTF_8);
    String payloadBase64 = Base64.urlSafeEncode((payload + " ").getBytes(UTF_8));
    String modifiedCompact = parts[0] + "." + payloadBase64 + "." + parts[2];

    JwtValidator validator = new JwtValidator.Builder().build();
    assertThrows(
        GeneralSecurityException.class, () -> verifier.verifyAndDecode(modifiedCompact, validator));
  }

  @Test
  @Parameters(method = "templates")
  public void createSignVerify_bitFlipped_throw(KeyTemplate template) throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeysetHandle handle = KeysetHandle.generateNew(template);
    JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
    RawJwt rawToken = new RawJwt.Builder().setIssuer("issuer").build();
    String result = signer.signAndEncode(rawToken);
    JwtValidator validator = new JwtValidator.Builder().build();
    char[] validJwt = result.toCharArray();

    // The last character is skipped because the base64 decoder ignores some of its bits, so a
    // flip there does not necessarily change the decoded signature.
    for (int i = 0; i < validJwt.length - 1; ++i) {
      // Flip each of the low 8 bits of the i-th character, one at a time.
      for (int b = 0; b < 8; ++b) {
        char[] invalidJwt = Arrays.copyOf(validJwt, validJwt.length);
        invalidJwt[i] = (char) (validJwt[i] ^ (1 << b));
        assertThrows(
            GeneralSecurityException.class,
            () -> verifier.verifyAndDecode(new String(invalidJwt), validator));
      }
    }
  }

  /**
   * Builds a compact JWS from arbitrary JSON {@code header} and {@code payload}, signed with
   * {@code rawSigner}. This bypasses the JWT API so tests can produce headers that the API
   * itself would never emit.
   */
  private static String generateSignedCompact(
      EcdsaSignJce rawSigner, JsonObject header, JsonObject payload)
      throws GeneralSecurityException {
    String payloadBase64 = Base64.urlSafeEncode(payload.toString().getBytes(UTF_8));
    String headerBase64 = Base64.urlSafeEncode(header.toString().getBytes(UTF_8));
    String unsignedCompact = headerBase64 + "." + payloadBase64;
    String signature =
        Base64.urlSafeEncode(rawSigner.sign(unsignedCompact.getBytes(UTF_8)));
    return unsignedCompact + "." + signature;
  }

  @Test
  public void createSignVerify_withDifferentHeaders() throws Exception {
    assumeFalse(TestUtil.isTsan());  // KeysetHandle.generateNew is too slow in Tsan.
    KeyTemplate template = JwtEcdsaSignKeyManager.jwtES256Template();
    KeysetHandle handle = KeysetHandle.generateNew(template);
    Keyset keyset = CleartextKeysetHandle.getKeyset(handle);
    JwtEcdsaPrivateKey keyProto =
        JwtEcdsaPrivateKey.parseFrom(
            keyset.getKey(0).getKeyData().getValue(), ExtensionRegistryLite.getEmptyRegistry());

    // Rebuild a raw ECDSA signer from the generated key so tokens with hand-crafted headers
    // can be signed with the same key the verifier uses.
    JwtEcdsaAlgorithm algorithm = keyProto.getPublicKey().getAlgorithm();
    ECPrivateKey privateKey =
        EllipticCurves.getEcPrivateKey(
            JwtEcdsaVerifyKeyManager.getCurve(algorithm), keyProto.getKeyValue().toByteArray());
    Enums.HashType hash = JwtEcdsaVerifyKeyManager.hashForEcdsaAlgorithm(algorithm);
    EcdsaSignJce rawSigner = new EcdsaSignJce(privateKey, hash, EcdsaEncoding.IEEE_P1363);

    JsonObject payload = new JsonObject();
    payload.addProperty(JwtNames.CLAIM_ISSUER, "issuer");
    JwtValidator validator = new JwtValidator.Builder().build();
    JwtPublicKeyVerify verifier =
        handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);

    // Normal, valid signed compact.
    String unsignedCompact =
        JwtFormat.createUnsignedCompact(algorithm.name(), Optional.empty(), payload.toString());
    String normalSignedCompact =
        JwtFormat.createSignedCompact(
            unsignedCompact, rawSigner.sign(unsignedCompact.getBytes(US_ASCII)));
    assertThat(verifier.verifyAndDecode(normalSignedCompact, validator).getIssuer())
        .isEqualTo("issuer");

    // Valid token, with "typ" set in the header.
    JsonObject goodHeader = new JsonObject();
    goodHeader.addProperty(JwtNames.HEADER_ALGORITHM, algorithm.name());
    goodHeader.addProperty("typ", "JWT");
    String goodSignedCompact = generateSignedCompact(rawSigner, goodHeader, payload);
    assertThat(verifier.verifyAndDecode(goodSignedCompact, validator).getIssuer())
        .isEqualTo("issuer");

    // Invalid token with an empty header (no algorithm).
    JsonObject emptyHeader = new JsonObject();
    String emptyHeaderSignedCompact = generateSignedCompact(rawSigner, emptyHeader, payload);
    assertThrows(
        GeneralSecurityException.class,
        () -> verifier.verifyAndDecode(emptyHeaderSignedCompact, validator));

    // Invalid token with a valid but incorrect algorithm in the header.
    JsonObject badAlgoHeader = new JsonObject();
    badAlgoHeader.addProperty(JwtNames.HEADER_ALGORITHM, "RS256");
    String badAlgoSignedCompact = generateSignedCompact(rawSigner, badAlgoHeader, payload);
    assertThrows(
        GeneralSecurityException.class,
        () -> verifier.verifyAndDecode(badAlgoSignedCompact, validator));

    // A token with an unknown "typ" in the header is still valid.
    JsonObject unknownTypeHeader = new JsonObject();
    unknownTypeHeader.addProperty(JwtNames.HEADER_ALGORITHM, algorithm.name());
    unknownTypeHeader.addProperty("typ", "unknown");
    String unknownTypeSignedCompact = generateSignedCompact(rawSigner, unknownTypeHeader, payload);
    assertThat(verifier.verifyAndDecode(unknownTypeSignedCompact, validator).getIssuer())
        .isEqualTo("issuer");
  }
}
